Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import os | |
| from collections import defaultdict | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| import gradio as gr | |
| import requests | |
| import spaces | |
| import torch | |
| import yaml | |
| from gradio_rangeslider import RangeSlider | |
| from guidance import json as gen_json | |
| from guidance.models import Transformers | |
| from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed | |
| from schema import GDCCohortSchema # isort: skip | |
| from scheduler import ParquetScheduler # isort: skip | |
# Example prompts offered to the user through the Examples widget.
EXAMPLE_INPUTS = [
    "bam files for TCGA-BRCA",
    "kidney or adrenal gland cancers with alcohol history",
    "tumor samples from male patients with acute myeloid lymphoma",
]

# GDC REST endpoint queried for case counts and facet bucket counts.
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"

# The tokenizer is loaded from the same repository as the model weights.
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
TOKENIZER_NAME = MODEL_NAME

# Optional credentials; both resolve to None when the variable is unset.
MODEL_READ_TOKEN = os.environ.get("MODEL_READ_TOKEN")
DATASET_WRITE_TOKEN = os.environ.get("DATASET_WRITE_TOKEN")
# Load the tab/card layout and derive the lookup tables used throughout the app.
with open("config.yaml", "r") as f:
    CONFIG = yaml.safe_load(f)

TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]

# Cards are flattened across tabs in config order; the three lists below are
# index-aligned with each other.
CARD_NAMES = [card["name"] for tab in CONFIG["tabs"] for card in tab["cards"]]
CARD_FIELDS = [card["field"] for tab in CONFIG["tabs"] for card in tab["cards"]]

# Bidirectional mapping between a card's display name and its field.
CARD_2_FIELD = dict(zip(CARD_NAMES, CARD_FIELDS))
FIELD_2_CARD = dict(zip(CARD_FIELDS, CARD_NAMES))

# A card's "values" is either a list (checkbox card) or a dict of meta options
# (range card).
CARD_2_VALUES = {
    card["name"]: card["values"] for tab in CONFIG["tabs"] for card in tab["cards"]
}

# Comma-separated facet names requested from the API. Cards whose values are
# a dict (range cards) are skipped in the bin counts.
FACETS_STR = ",".join(
    field.replace("cases.", "")
    for field, name in zip(CARD_FIELDS, CARD_NAMES)
    if not isinstance(CARD_2_VALUES[name], dict)
)
# Optional preference logging: PREF_DS names the dataset repo that user votes
# are appended to. `scheduler` is only defined when PREF_DS is set.
PREF_DS = os.environ.get("PREF_DS", False)
if PREF_DS:
    assert DATASET_WRITE_TOKEN is not None
    scheduler = ParquetScheduler(
        repo_id=PREF_DS,
        token=DATASET_WRITE_TOKEN,
        schema={
            column: {"_type": "Value", "dtype": dtype}
            for column, dtype in (
                ("prompt", "string"),
                ("cohort_filter", "string"),
                ("preference", "bool"),
                ("timestamp", "string"),
            )
        },
    )
# Load the tokenizer and model once at import time; the model is moved to the
# GPU when one is available and switched to eval mode.
tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=MODEL_READ_TOKEN)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=MODEL_READ_TOKEN)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
# Generate cohort filter JSON from free text
def generate_filter(query: str) -> str:
    """
    Converts a free text description of a cancer cohort into a GDC structured cohort filter.

    Args:
        query (str): The free text cohort description

    Returns:
        str: JSON structured GDC cohort filter
    """
    # Re-seed on every call.
    set_seed(42)

    # Wrap the already-loaded model/tokenizer for schema-constrained decoding.
    lm = Transformers(model=model, tokenizer=tok)
    lm += query

    # Constrain the generation to the cohort schema and capture it as "cohort".
    constrained_json = gen_json(
        name="cohort",
        schema=GDCCohortSchema,
        temperature=0,
        max_tokens=1024,
    )
    lm += constrained_json

    # Round-trip through the JSON parser to pretty-print the captured text.
    raw_filter = lm["cohort"]
    return json.dumps(json.loads(raw_filter), indent=4)
def _prepare_value_count(value: str, count: int) -> str:
    """Render a facet value and its case count as ``"<value> [<count>]"``."""
    return "{} [{}]".format(value, count)
def _get_base_value(value_count: str) -> str:
    """Strip the trailing ``" [<count>]"`` suffix from a choice label.

    Everything from the last ``" ["`` onwards is removed; a label without that
    marker is returned unchanged.
    """
    head, marker, _ = value_count.rpartition(" [")
    return head if marker else value_count
def _patch_range_filters_for_facet_endpoint(cohort_filter: str) -> str:
    """Flatten directly nested ``and`` ops so the filter can be sent with facets.

    Range selectors are represented as a nested ``and``; ``facets`` and nested
    ``and`` don't seem to play well together, so the nesting is removed for
    query execution only. This is equivalent because the top level is always
    ``and``. The nested form is still kept for presentation and for model
    generations.

    Args:
        cohort_filter: JSON string of the cohort filter.

    Returns:
        JSON string of the same filter with one level of nested ``and`` inlined.
    """
    parsed = json.loads(cohort_filter)
    flat_ops = []
    for op in parsed["content"]:
        # assumes no deeper than single level nesting
        if op["op"] == "and":
            flat_ops.extend(op["content"])
        else:
            flat_ops.append(op)
    parsed["content"] = flat_ops
    return json.dumps(parsed)
| def _convert_cohort_filter_to_lookup(cohort_filter: str) -> dict[str, int | list[str]]: | |
| # Pre-flatten nested ops for easier mapping in next step | |
| flattened_ops = [] | |
| for op in json.loads(cohort_filter)["content"]: | |
| # nested `and` can only be 1 deep based on schema | |
| if op["op"] == "and": | |
| flattened_ops.extend(op["content"]) | |
| else: | |
| flattened_ops.append(op) | |
| # Prepare and validate generated filters | |
| selected_field_2_values = dict() | |
| for op in flattened_ops: | |
| assert op["op"] in ["in", "=", "<", ">", "<=", ">="], f"Unknown handling for op: {op}" # fmt: skip | |
| content = op["content"] | |
| field, value = content["field"], content["value"] | |
| if op["op"] == "=": | |
| # convert = to <=,>= ops so it can be filled into card | |
| # use flattened_ops as a queue, defer current op | |
| flattened_ops.append( | |
| { | |
| "op": "<=", | |
| "content": content, | |
| } | |
| ) | |
| flattened_ops.append( | |
| { | |
| "op": ">=", | |
| "content": content, | |
| } | |
| ) | |
| continue # defer current op | |
| elif op["op"] == "<": | |
| # comparator values are ints so can convert to lte by sub 1 | |
| op["op"] = "<=" | |
| value -= 1 | |
| elif op["op"] == ">": | |
| # comparator values are ints so can convert to gte by add 1 | |
| op["op"] = ">=" | |
| value += 1 | |
| # comp ops will duplicate name, disambiguate by appending comp | |
| if op["op"] != "in": | |
| field += "_" + op["op"] | |
| # check that fields are not duplicated | |
| if field in selected_field_2_values: | |
| raise ValueError(f"{field} is ambiguously duplicated") | |
| selected_field_2_values[field] = value | |
| return selected_field_2_values | |
def _convert_cohort_filter_to_active_selections(cohort_filter: str) -> list[str]:
    """Build the "currently selected" labels for a cohort filter.

    Checkbox cards contribute one ``"<CARD NAME>: <value>"`` label per selected
    value that is among the card's configured values. Range cards contribute
    ``"<CARD NAME>: ≥<n>"`` / ``"<CARD NAME>: ≤<n>"`` labels for bounds that
    differ from the card's configured min / max.

    Args:
        cohort_filter: JSON string of the cohort filter.

    Returns:
        Labels in the order the filter's fields are encountered.

    Raises:
        ValueError: If a card's configured values are neither a list nor a
            dict, or a range key carries neither comparator suffix.
    """
    active_choices = []
    lookup = _convert_cohort_filter_to_lookup(cohort_filter)
    for field, values in lookup.items():
        # drop the comparator suffix added by the lookup conversion
        base_field = field.replace("_<=", "").replace("_>=", "")
        card_name = FIELD_2_CARD[base_field]
        default_values = CARD_2_VALUES[card_name]

        if isinstance(default_values, list):
            # checkbox: values outside the configured set are dropped
            # (model hallucination?)
            possible_values = set(default_values)
            active_choices.extend(
                f"{card_name.upper()}: {value}"
                for value in values
                if value in possible_values
            )
            continue

        if not isinstance(default_values, dict):
            raise ValueError(f"Unknown values for card {card_name}")

        # range-slider, maybe other options in the future?
        assert default_values["type"] == "range", f"Expected range slider for card {card_name}"  # fmt: skip
        assert isinstance(values, int), "values should be integer for range op"
        if ">=" in field:
            symbol, default_bound = "≥", default_values["min"]
        elif "<=" in field:
            symbol, default_bound = "≤", default_values["max"]
        else:
            raise ValueError(f"Unclear how field is not l/gte: {field}")
        # a bound equal to the slider default is not an active selection
        if values != default_bound:
            active_choices.append(f"{card_name.upper()}: {symbol}{values}")
    return active_choices
def _convert_cohort_filter_to_cards(cohort_filter: str, api_data: dict) -> list[dict]:
    """Build one update per filter card from the filter and the API facets.

    Checkbox cards get their choices from the union of the selected values and
    the facet bucket counts (a selected value may have a count of 0), with the
    selected values listed first. Range cards get their ``(lo, hi)`` value
    from the filter, falling back to the card's configured min / max.

    Args:
        cohort_filter: JSON string of the cohort filter.
        api_data: The ``data`` payload of the cases endpoint response; its
            ``aggregations`` entry is read for bucket counts.

    Returns:
        One ``gr.update`` per card, in ``CARD_NAMES`` order.

    Raises:
        ValueError: If a card's configured values are neither a list nor a dict.
    """
    # lookup consumed while iterating; whatever is left at the end is unmatched
    remaining = _convert_cohort_filter_to_lookup(cohort_filter)

    card_updates = []
    for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
        default_values = CARD_2_VALUES[card_name]

        if isinstance(default_values, list):
            # checkbox selector
            facet_key = card_field.replace("cases.", "")
            bucket_counts = {
                bucket["key"]: bucket["doc_count"]
                for bucket in api_data["aggregations"][facet_key]["buckets"]
            }
            possible_values = set(default_values)

            # selected values go first as both values and choices
            checked = []
            if card_field in remaining:
                unmatched_values = []
                for selected_value in remaining.pop(card_field):
                    if selected_value not in possible_values:
                        print(
                            f"{card_field} value {selected_value} is not in the "
                            "list of default values, is this a model hallucination?"
                        )
                        # distinct from a valid value with 0 count
                        unmatched_values.append(selected_value)
                        continue
                    count = bucket_counts.pop(selected_value, 0)
                    checked.append(_prepare_value_count(selected_value, count))
                if unmatched_values:
                    # collect unmatched values back into the lookup, which
                    # may otherwise be tracking unmatched fields
                    remaining[card_field] = unmatched_values

            # fill in remaining possible values from bucket counts; buckets
            # outside the configured values are skipped (schema mismatch?)
            unchecked = [
                _prepare_value_count(other_choice, count)
                for other_choice, count in bucket_counts.items()
                if other_choice in possible_values
            ]

            update_obj = gr.update(
                choices=sorted(checked) + sorted(unchecked),
                value=checked,  # order given here is thought to preserve selection order
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            # nothing to do with bucket counts for range slider
            assert (
                default_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            _min, _max = default_values["min"], default_values["max"]
            # the model may output a flat range or a nested range; both end
            # up as suffixed keys in the lookup
            lo = remaining.pop(card_field + "_>=", _min)
            hi = remaining.pop(card_field + "_<=", _max)
            assert (
                lo >= _min
            ), f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
            assert (
                hi <= _max
            ), f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
            update_obj = gr.update(value=(lo, hi))
        else:
            raise ValueError(f"Unknown card type {card_name}")

        card_updates.append(update_obj)

    # the lookup may now hold remaining, unmatched values
    # edit: updated json schema with enumerated fields should prevent unmatched fields
    if remaining:
        print(f"Unmatched field/values in filter selections: {remaining}")
    return card_updates
def update_elements_from_filtered_api_call(cohort_filter: str) -> list[dict]:
    """Query the GDC cases endpoint and refresh every dependent UI element.

    Args:
        cohort_filter: JSON string of the cohort filter currently displayed.

    Returns:
        Updates, in order, for:
        - counter (text)
        - active selections (checkbox group)
        - upvote (enable button, reset text)
        - downvote (enable button, reset text)
        - cards (one update per filter card)

    Raises:
        Exception: If the API answers with a non-success status code.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    # --- Execute API Call ---
    # nested `and` is flattened for the request only; the unpatched filter is
    # still used for the UI updates below
    patched_cohort_filter = _patch_range_filters_for_facet_endpoint(cohort_filter)
    params = {
        "filters": patched_cohort_filter,
        "facets": FACETS_STR,
        "pretty": "false",
        "format": "JSON",
        "size": 0,
    }
    # Without a timeout a stalled connection would block this callback forever.
    response = requests.get(GDC_CASES_API_ENDPOINT, params=params, timeout=60)
    if not response.ok:
        # Error bodies are not guaranteed to be JSON; fall back to the raw
        # text so the real HTTP error is not masked by a decode failure.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise Exception(f"API error: {response.status_code}\n{detail}")
    api_data = response.json()["data"]

    # --- Update Elements ---
    case_count = api_data["pagination"]["total"]
    active_choices = _convert_cohort_filter_to_active_selections(cohort_filter)
    card_updates = _convert_cohort_filter_to_cards(cohort_filter, api_data)
    return [
        gr.update(value=f"{case_count} Cases"),  # case counter
        gr.update(choices=active_choices, value=active_choices),  # actives
        gr.update(interactive=True, value="⬆"),
        gr.update(interactive=True, value="⬇"),
    ] + card_updates
def update_json_from_cards(*selected_filters_per_card: tuple[str]) -> str:
    """Rebuild the cohort filter JSON from the current state of every card.

    Args:
        *selected_filters_per_card: One entry per card, in ``CARD_NAMES``
            order. Checkbox cards pass their selected labels; range cards pass
            a ``(lo, hi)`` pair.

    Returns:
        A ``gr.update`` carrying the pretty-printed filter JSON as its value.
        NOTE(review): the annotation says ``str`` but a ``gr.update`` payload
        is what is returned.

    Raises:
        ValueError: If a card's configured values are neither a list nor a dict.
    """
    ops = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        # the configured values determine the card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]

        if isinstance(default_values, list):
            # checkbox: emit a single `in` op when anything is ticked
            if len(selected_filters) > 0:
                # labels carry a " [count]" suffix that must be removed
                base_values = [_get_base_value(label) for label in selected_filters]
                ops.append(
                    {
                        "op": "in",
                        "content": {
                            "field": CARD_2_FIELD[card_name],
                            "value": base_values,
                        },
                    }
                )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            assert (
                default_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selected_filters
            bounds = (
                (lo, default_values["min"], ">="),
                (hi, default_values["max"], "<="),
            )
            # only add a range filter for bounds that are not at the default
            subops = [
                {
                    "op": comp,
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        "value": int(val),
                    },
                }
                for val, limit, comp in bounds
                if val != limit
            ]
            if subops:
                # range bounds are wrapped in a nested `and`
                ops.append({"op": "and", "content": subops})
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    filter_json = json.dumps({"op": "and", "content": ops}, indent=4)
    return gr.update(value=filter_json)
def update_json_from_active(active_selections: list[str]) -> str:
    """Rebuild the cohort filter JSON from the active-selection labels.

    The labels are regrouped per card and mocked up as card selections, then
    handed to ``update_json_from_cards``.

    Args:
        active_selections: Labels of the form ``"<CARD NAME>: <value>"``.

    Returns:
        Whatever ``update_json_from_cards`` returns for the mocked-up cards.

    Raises:
        ValueError: If a card's configured values are neither a list nor a
            dict, or a range card has the same comparator twice.
    """
    # group label values by their (upper-cased) card name
    grouped_selections = defaultdict(list)
    for label in active_selections:
        idx = label.find(": ")
        grouped_selections[label[:idx]].append(label[idx + 2 :])

    # mock-up as card selections and defer to update_json_from_cards
    selected_filters_per_card = []
    for original_name in CARD_NAMES:
        default_values = CARD_2_VALUES[original_name]
        card_name = original_name.upper()  # match active selections casing
        # `.get` so that looking up a missing card does not create a key
        selected_values = grouped_selections.get(card_name, [])

        if isinstance(default_values, list):
            # checkbox group: the selections as-is (empty when none)
            selected_filters_per_card.append(selected_values)
            continue

        if not isinstance(default_values, dict):
            raise ValueError(f"Unknown card type for card: {card_name}")

        # range selector: no selections falls through to the default range
        assert (
            len(selected_values) <= 2
        ), "Cannot do range op with more than 2 ops"
        assert all(
            "≥" in x or "≤" in x for x in selected_values
        ), "Unclear how ops besides l/gte are in active selection, did that logic change?"
        selected_range = {}
        for x in selected_values:
            comp = ">=" if "≥" in x else "<="
            # shortcut: the int follows the single leading comparator symbol;
            # re-check this if ops besides l/gte ever appear in the labels
            value = int(x[1:])
            if comp in selected_range:
                raise ValueError(f"Duplicated comparator {comp} for {card_name}")
            selected_range[comp] = value
        selected_filters_per_card.append(
            (
                selected_range.get(">=", default_values["min"]),
                selected_range.get("<=", default_values["max"]),
            )
        )

    return update_json_from_cards(*selected_filters_per_card)
def get_default_filter() -> str:
    """Show the usage warning and return the empty cohort filter as JSON.

    Returns:
        Pretty-printed JSON for an ``and`` op with no content.
    """
    # duration=None is passed so the warning is not given a display time
    gr.Warning(
        message="GDC Cohort Copilot can make mistakes. Interactively refine your search using the checkboxes.",
        duration=None,
        title="GDC Cohort Copilot Should Be Used Interactively!",
    )
    empty_filter = {"op": "and", "content": []}
    return json.dumps(empty_filter, indent=4)
def set_active_tab(selected_tab: str) -> list[dict]:
    """Show the selected tab's container and highlight its button.

    Args:
        selected_tab: Name of the tab to activate.

    Returns:
        Visibility updates for every tab (in ``TAB_NAMES`` order) followed by
        button-variant updates for every tab (same order).
    """
    visibility_updates = []
    variant_updates = []
    for tab in TAB_NAMES:
        is_selected = tab == selected_tab
        visibility_updates.append(gr.update(visible=is_selected))
        variant_updates.append(
            gr.update(variant="primary" if is_selected else "secondary")
        )
    return visibility_updates + variant_updates
def save_user_preference(cohort_query: str, cohort_filter: str, preference: bool) -> list[dict]:  # fmt: skip
    """Record an up/down vote for the current prompt and filter.

    Args:
        cohort_query: The free text description entered by the user.
        cohort_filter: JSON string of the cohort filter being voted on.
        preference: True for an upvote, False for a downvote.

    Returns:
        Updates that disable the upvote and downvote buttons and relabel them
        to show which vote was cast.
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    record = {
        "prompt": cohort_query,
        "cohort_filter": json.dumps(json.loads(cohort_filter)),  # remove whitespace
        "preference": preference,
        "timestamp": timestamp,
    }

    if not PREF_DS:
        print(
            f"No preference dataset configured, "
            f"set PREF_DS env var to point to a HuggingFace Dataset Repo. "
            f"Would have logged {record}"
        )
    else:
        scheduler.append(record)
        print(f"Logged user preference data at {timestamp}")

    # disable buttons; "--" stands in for a blank label because whitespace
    # seems to be escaped by gradio
    upval, downval = ("✓", "--") if preference else ("--", "✗")
    return [
        gr.update(interactive=False, value=upval),
        gr.update(interactive=False, value=downval),
    ]
# Browser-side handler for the export button: fetches the case ids matching
# the current filter and triggers a download of them, one id per line.
# Braces are doubled because this is an f-string; only the endpoint URL is
# interpolated.
# The export button shows a spinner while the request is in flight and is
# restored in `.finally`, so a failed request no longer leaves it disabled.
DOWNLOAD_CASES_JS = f"""
function download_cases(filter_str) {{
    const params = new URLSearchParams();
    params.set('fields', 'case_id');
    params.set('format', 'JSON');
    params.set('size', 100000);
    params.set('filters', filter_str);
    const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();
    const button = document.getElementById("download-btn");
    button.innerHTML = '<div class="spinner"></div>';
    button.disabled = true;
    fetch(url).then(resp => {{
        if (!resp.ok) throw new Error("Failed to fetch TSV.");
        return resp.json();
    }})
    .then(data => {{
        const ids = data.data.hits.map(item => item.id);
        const text = ids.join("\\n");
        const blob = new Blob([text], {{type: "text/plain"}});
        return blob;
    }})
    .then(blob => {{
        const objectUrl = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = objectUrl;
        a.download = "gdc_cohort_case_ids.tsv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(objectUrl);
    }})
    .catch(error => {{
        alert("Download failed: " + error.message);
    }})
    .finally(() => {{
        button.innerHTML = 'Export to GDC';
        button.disabled = false;
    }});
}}
"""
# UI layout and event wiring. Every event ultimately funnels through
# `json_output`: the model, the filter cards and the active selections all
# write the filter JSON, and a change of the JSON refreshes everything else.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown("# GDC Cohort Copilot")

    # --- Prompt input, case counter and export button ---
    with gr.Row(equal_height=True):
        with gr.Column(scale=7):
            text_input = gr.Textbox(
                label="Describe the cohort you're looking for:",
                info=(
                    "Only provide the cohort characteristics. "
                    "Do not include extraneous text. "
                    "For example, write 'patients with X' "
                    "instead of 'I would like patients with X':"
                ),
                submit_btn="Generate Cohort",
                elem_id="description-input",
                placeholder="Enter a cohort description to begin...",
            )
        with gr.Column(scale=1, min_width=150):
            case_counter = gr.Text(
                show_label=False,
                interactive=False,
                container=False,
                elem_id="case-counter",
                min_width=150,
            )
            case_download = gr.Button(
                value="Export to GDC",
                min_width=150,
                elem_id="download-btn",
            )

    # --- Examples, generated filter JSON and vote buttons ---
    with gr.Row(equal_height=True):
        with gr.Column(scale=2, min_width=250):
            gr.Examples(
                examples=EXAMPLE_INPUTS,
                inputs=text_input,
            )
        with gr.Column(scale=7):
            json_output = gr.Code(
                label="Cohort Filter JSON",
                language="json",
                interactive=False,
                show_label=True,
                container=True,
                elem_id="json-output",
            )
        with gr.Column(scale=1, min_width=50):
            gr.Markdown(
                "Is this correct?",
                elem_id="vote-label",
            )
            upvote = gr.Button(
                value="⬆",
                min_width=50,
                elem_id="upvote-btn",
            )
            # NOTE(review): this elem_id duplicates the export button's
            # "download-btn"; it looks like it was meant to be a
            # downvote-specific id. Left as-is because style.css may depend
            # on it - confirm before changing.
            downvote = gr.Button(
                value="⬇",
                min_width=50,
                elem_id="download-btn",
            )

    # --- Usage notes ---
    with gr.Row():
        gr.Markdown(
            "The generated cohort filter will autopopulate into the filter cards below. "
            "**<u>GDC Cohort Copilot can make mistakes!</u>** "
            "Refine your search using the interactive checkboxes. "
            "Note that many other options can be found by selecting the different tabs. "
            "**<u>If you'd like to help us improve our model</u>**, you can use the up or down vote button to send us feedback. "
            "We'll only save the current free text description, the cohort filter JSON, and your vote. "
            "You can also show us what the right filter should have been by manually refining it using the checkboxes, before up voting."
        )

    # --- Currently selected filters ---
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Currently Selected Filters")
        with gr.Column(scale=4):
            active_selections = gr.CheckboxGroup(
                choices=[],
                show_label=False,
                interactive=True,
                elem_id="active-selections",
            )

    # --- Tab selectors and filter cards ---
    with gr.Row():
        # Tab selectors; the first tab starts out highlighted
        with gr.Column(scale=1, min_width=250):
            tab_buttons = [
                gr.Button(
                    value=tab_name,
                    variant="primary" if tab_name == TAB_NAMES[0] else "secondary",
                )
                for tab_name in TAB_NAMES
            ]

        # Filter cards; one container per tab, only the first is visible
        tab_containers = []
        filter_cards = []
        for tab in CONFIG["tabs"]:
            is_default_tab = tab["name"] == TAB_NAMES[0]
            with gr.Column(scale=4, visible=is_default_tab) as tab_container:
                tab_containers.append(tab_container)
                with gr.Row(elem_classes=["card-group"]):
                    for card in tab["cards"]:
                        if isinstance(card["values"], list):
                            # choices are filled in by the first API update
                            filter_card = gr.CheckboxGroup(
                                choices=[],
                                label=card["name"],
                                interactive=True,
                                elem_classes=["filter-card"],
                            )
                        else:
                            # values is a dictionary and defines some meta options
                            metaopts = card["values"]
                            assert (
                                "type" in metaopts
                                and metaopts["type"] == "range"
                                and all(k in metaopts for k in ["min", "max"])
                            ), f"Unknown meta options for {card['name']}"
                            info = "Inclusive range"
                            if "unit" in metaopts:
                                info += f", units in {metaopts['unit']}"
                            filter_card = RangeSlider(
                                label=card["name"],
                                info=info,
                                minimum=metaopts["min"],
                                maximum=metaopts["max"],
                                step=1,  # assume integer
                                elem_classes=["filter-card", "filter-range"],
                            )
                        filter_cards.append(filter_card)

    # Toggle card group (tab) visibility
    for tab_button, name in zip(tab_buttons, TAB_NAMES):
        tab_button.click(
            fn=set_active_tab,
            inputs=gr.State(name),
            outputs=tab_containers + tab_buttons,
            # api_name=False,
            show_api=False,
        )

    # Callback for case download button
    case_download.click(
        fn=None,  # apparently this isn't the same as not specifying it, even though the default is None?
        js=DOWNLOAD_CASES_JS,  # need custom JS to execute browser side download
        inputs=json_output,
        # api_name=False,
        show_api=False,
    )

    # Enable user preference logging; both buttons share the same handler
    # and differ only in the boolean passed as the vote
    for vote_button, vote_value in ((upvote, True), (downvote, False)):
        vote_button.click(
            fn=save_user_preference,
            inputs=[text_input, json_output, gr.State(vote_value)],
            outputs=[upvote, downvote],
            # api_name=False,
            show_api=False,
        )

    # Model generation should change the JSON filter
    # All other element updates cascade
    # This is the only API that should be exposed
    text_input.submit(
        fn=generate_filter,
        inputs=text_input,
        outputs=json_output,
    )

    # Changing the card selections should change the JSON filter
    # All other element updates (including cards themselves) cascade
    for filter_card in filter_cards:
        # range sliders are bound to `release`, checkbox groups to `input`
        if isinstance(filter_card, RangeSlider):
            bind_event = filter_card.release
        else:
            bind_event = filter_card.input
        bind_event(
            fn=update_json_from_cards,
            inputs=filter_cards,
            outputs=json_output,
            # api_name=False,
            show_api=False,
        )

    # Changing the active selections should change the JSON filter
    # All other element updates (including active selections itself) cascade
    active_selections.input(
        fn=update_json_from_active,
        inputs=active_selections,
        outputs=json_output,
        # api_name=False,
        show_api=False,
    )

    # JSON filter change executes API call and updates all elements
    json_output.change(
        fn=update_elements_from_filtered_api_call,
        inputs=json_output,
        outputs=[case_counter, active_selections, upvote, downvote] + filter_cards,
        # api_name=False,
        show_api=False,
    )

    # Trigger initial update
    demo.load(
        fn=get_default_filter,
        inputs=None,
        outputs=json_output,
        # api_name=False,  # this breaks the API functionality, not sure why
        show_api=False,  # so just hide the API endpoints instead, not ideal
        # the weirdness with the API toggle seems true for all disabled API endpoints
    )
# Script entry point: launch the app with the MCP server option turned on.
if __name__ == "__main__":
    demo.launch(mcp_server=True)