import json
import os
from collections import defaultdict

import gradio as gr
import requests
import spaces
import torch
import yaml
from gradio_rangeslider import RangeSlider
from guidance import json as gen_json
from guidance.models import Transformers
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed

from schema import GDCCohortSchema  # isort: skip

# When the DEBUG environment variable is present the model is never loaded and
# `generate_filter` returns DUMMY_FILTER instead of running generation.
DEBUG = "DEBUG" in os.environ

EXAMPLE_INPUTS = [
    "bam files for TCGA-BRCA",
    "kidney or adrenal gland cancers with alcohol history",
    "tumor samples from male patients with acute myeloid lymphoma",
]
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
# Seconds to wait on the GDC API before giving up, so a stalled request cannot
# hang an event handler indefinitely.
GDC_REQUEST_TIMEOUT = 30
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
TOKENIZER_NAME = MODEL_NAME
AUTH_TOKEN = os.environ.get("HF_TOKEN", False)  # HF_TOKEN must be set to use auth

with open("config.yaml", "r", encoding="utf-8") as f:
    CONFIG = yaml.safe_load(f)

# Flattened views over CONFIG["tabs"][*]["cards"][*]; all lists share card order.
TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]
CARD_NAMES = [card["name"] for tab in CONFIG["tabs"] for card in tab["cards"]]
CARD_FIELDS = [card["field"] for tab in CONFIG["tabs"] for card in tab["cards"]]
CARD_2_FIELD = dict(zip(CARD_NAMES, CARD_FIELDS))
CARD_2_VALUES = {
    card["name"]: card["values"] for tab in CONFIG["tabs"] for card in tab["cards"]
}
FACETS_STR = ",".join(
    [
        field.replace("cases.", "")
        for field, name in zip(CARD_FIELDS, CARD_NAMES)
        if not isinstance(CARD_2_VALUES[name], dict)
        # ^ skip range facets in bin counts
    ]
)

# Operators accepted from a generated filter; comparison ops feed range cards.
_COMPARISON_OPS = ("<", ">", "<=", ">=")
_SUPPORTED_OPS = ("in", "=") + _COMPARISON_OPS

if not DEBUG:
    tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
    # for some reason, pre-invoking tokenizer prevents endless generation when using guidance
    # opened ticket here: https://github.com/guidance-ai/guidance/issues/1322
    tok("foobar")
    model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
    model = model.eval()

# Canned filter returned by `generate_filter` in DEBUG mode.
DUMMY_FILTER = json.dumps(
    {
        "op": "and",
        "content": [
            {
                "op": "in",
                "content": {
                    "field": "cases.project.project_id",
                    "value": ["TCGA-BRCA"],
                },
            },
            {
                "op": "in",
                "content": {
                    "field": "cases.project.program.name",
                    "value": ["TCGA"],
                },
            },
            {
                "op": "and",
                "content": [
                    {
                        "op": ">=",
                        "content": {
                            "field": "cases.diagnoses.age_at_diagnosis",
                            "value": 7305,
                        },
                    },
                    {
                        "op": "<=",
                        "content": {
                            "field": "cases.diagnoses.age_at_diagnosis",
                            "value": 14610,
                        },
                    },
                ],
            },
        ],
    },
    indent=4,
)


def _flatten_nested_and(ops):
    """Return `ops` with each directly nested `and` group replaced by its members.

    Only one level of nesting is expanded; members of a nested group are
    copied through as-is.
    """
    flattened = []
    for op in ops:
        if op["op"] == "and":
            flattened.extend(op["content"])
        else:
            flattened.append(op)
    return flattened


def _require_range_card(card_name, values):
    """Raise ValueError unless the dict-valued card config declares type `range`."""
    if values.get("type") != "range":
        raise ValueError(f"Expected range slider for card {card_name}")


def _collect_generated_values(flattened_ops):
    """Map each generated op to a `{field key: value}` entry.

    `in` ops are keyed by the bare field name. Comparison ops are keyed by
    `<field>_<=` or `<field>_>=`: strict comparisons are converted to their
    inclusive form by adding/subtracting 1 (values are treated as ints), and
    `=` is expanded into both an inclusive lower and upper bound.

    Raises:
        ValueError: if an op is not supported or a key would be set twice.
    """
    generated = {}
    for op in flattened_ops:
        operator = op["op"]
        if operator not in _SUPPORTED_OPS:
            raise ValueError(f"Unknown handling for op: {op}")
        content = op["content"]
        field, value = content["field"], content["value"]

        if operator == "in":
            entries = [(field, value)]
        elif operator == "=":
            # convert = to <=,>= so it can be filled into a range card
            entries = [(f"{field}_<=", value), (f"{field}_>=", value)]
        elif operator == "<":
            entries = [(f"{field}_<=", value - 1)]
        elif operator == ">":
            entries = [(f"{field}_>=", value + 1)]
        else:
            # comp ops will duplicate name, disambiguate by appending comp
            entries = [(f"{field}_{operator}", value)]

        for key, entry_value in entries:
            if key in generated:
                raise ValueError(f"{key} is ambiguously duplicated")
            generated[key] = entry_value
    return generated


# Generate cohort filter JSON from free text
@spaces.GPU(duration=15)
def generate_filter(query: str) -> str:
    """
    Converts a free text description of a cancer cohort into a GDC structured cohort filter.

    Args:
        query (str): The free text cohort description

    Returns:
        str: JSON structured GDC cohort filter
    """
    if DEBUG:
        return DUMMY_FILTER

    set_seed(42)
    lm = Transformers(
        model=model,
        tokenizer=tok,
        # sampling_params=SamplingParams,
    )
    lm += query
    lm += gen_json(
        name="cohort", schema=GDCCohortSchema, temperature=0, max_tokens=1024
    )
    # Round-trip through json so the returned text is consistently indented.
    return json.dumps(json.loads(lm["cohort"]), indent=4)


# Transform query to filter to checkbox selections (and update json box)
def process_query(query):
    """Generate a filter for `query` and translate it into card updates.

    Args:
        query: Free text cohort description.

    Returns:
        list: One `gr.update` per card (in CARD_NAMES order) followed by an
        update carrying the generated filter JSON text.

    Raises:
        ValueError: if the generated filter uses an unsupported op, sets the
            same field/bound twice, falls outside a range card's configured
            bounds, or a card has an unrecognised config.
    """
    cohort_filter_str = generate_filter(query)
    cohort_filter = json.loads(cohort_filter_str)

    # nested `and` can only be 1 deep based on schema
    flattened_ops = _flatten_nested_and(cohort_filter["content"])
    generated_field_2_values = _collect_generated_values(flattened_ops)

    # Need to update all cards, so walk every configured card as reference
    card_updates = []
    for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            allowed = set(default_values)
            selected_values = generated_field_2_values.pop(card_field, [])
            matched = [v for v in selected_values if v in allowed]
            # model hallucination? keep for the diagnostic print below
            unmatched = [v for v in selected_values if v not in allowed]
            if unmatched:
                generated_field_2_values[card_field] = unmatched
            update_obj = gr.update(
                choices=default_values,  # reset choices to the configured values
                value=matched,  # will override existing selections
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            _require_range_card(card_name, default_values)
            _min = default_values["min"]
            _max = default_values["max"]
            # A missing bound falls back to the slider's configured limit
            lo = generated_field_2_values.pop(card_field + "_>=", _min)
            hi = generated_field_2_values.pop(card_field + "_<=", _max)
            if lo < _min:
                raise ValueError(
                    f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
                )
            if hi > _max:
                raise ValueError(
                    f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
                )
            update_obj = gr.update(value=(lo, hi))
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    # generated_field_2_values will have remaining, unmatched values
    # edit: updated json schema with enumerated fields prevents unmatched fields
    print(f"Unmatched values in model generation: {generated_field_2_values}")

    return card_updates + [gr.update(value=cohort_filter_str)]


# Update JSON based on checkbox selections
def update_json_from_cards(*selected_filters_per_card):
    """Build the cohort filter JSON from the current card selections.

    Args:
        *selected_filters_per_card: Current value of every card, in
            CARD_NAMES order.

    Returns:
        A `gr.update` whose value is the filter JSON text.

    Raises:
        ValueError: if a card has an unrecognised config.
    """
    ops = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        # use the default values to determine card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            # checkbox: emit a single `in` op, skipped when nothing is selected
            if len(selected_filters) > 0:
                ops.append(
                    {
                        "op": "in",
                        "content": {
                            "field": CARD_2_FIELD[card_name],
                            "value": [get_base_value(v) for v in selected_filters],
                        },
                    }
                )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            _require_range_card(card_name, default_values)
            lo, hi = selected_filters
            bounds = [
                (lo, default_values["min"], ">="),
                (hi, default_values["max"], "<="),
            ]
            # only add range filter if not default
            subops = [
                {
                    "op": comp,
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        "value": int(val),
                    },
                }
                for val, limit, comp in bounds
                if val != limit
            ]
            if subops:
                ops.append({"op": "and", "content": subops})
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    cohort_filter = {
        "op": "and",
        "content": ops,
    }
    return gr.update(value=json.dumps(cohort_filter, indent=4))


# Execute GDC API query and prepare checkbox + case counter updates
# Preserve prior selections
def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
    """Query the GDC cases endpoint and refresh checkbox counts and case total.

    Args:
        cohort_filter: Filter JSON text; an empty string queries without filters.
        *selected_filters_per_card: Current value of every card, in
            CARD_NAMES order, so prior selections can be preserved.

    Returns:
        list: One `gr.update` per card followed by an update for the case counter.

    Raises:
        RuntimeError: if the API responds with a non-OK status.
        ValueError: if a card has an unrecognised config.
    """
    card_2_selections = dict(zip(CARD_NAMES, selected_filters_per_card))

    params = {
        "facets": FACETS_STR,
        "pretty": "false",
        "format": "JSON",
        "size": 0,
    }
    if cohort_filter:
        # patch for range selectors which use nested `and`
        # seems `facets` and nested `and` don't play well together
        # so flatten direct nested `and` for query execution only
        # this is equivalent since our top-level is always `and`
        # keeping nested `and` for presentation and model generations though
        query_filter = json.loads(cohort_filter)
        query_filter["content"] = _flatten_nested_and(query_filter["content"])
        params["filters"] = json.dumps(query_filter)

    response = requests.get(
        GDC_CASES_API_ENDPOINT,
        params=params,
        timeout=GDC_REQUEST_TIMEOUT,
    )
    if not response.ok:
        # Use the raw body: an error response is not guaranteed to be JSON.
        raise RuntimeError(f"API error: {response.status_code}\n{response.text}")
    payload = response.json()

    # Update checkboxes with bin counts
    card_updates = []
    all_counts = payload["data"]["aggregations"]
    for card_name in CARD_NAMES:
        card_field = CARD_2_FIELD[card_name].replace("cases.", "")
        card_values = CARD_2_VALUES[card_name]
        if isinstance(card_values, list):
            card_counts = {
                bucket["key"]: bucket["doc_count"]
                for bucket in all_counts[card_field]["buckets"]
            }
            prior_bases = [
                get_base_value(selected)
                for selected in (card_2_selections[card_name] or [])
            ]
            prior_set = set(prior_bases)

            # base value -> "value [count]" label, kept in config order
            choice_mapping = {}
            for value_name in card_values:
                if value_name in card_counts:
                    count = card_counts[value_name]
                elif value_name in prior_set:
                    # Selected earlier but absent from the buckets now: keep
                    # it as a visible choice with a 0 count so the selection
                    # stays a member of `choices`.
                    count = 0
                else:
                    continue
                choice_mapping[value_name] = prepare_value_count(value_name, count)
            # Prior selections not present in the card config are still kept.
            for base_value in prior_bases:
                if base_value not in choice_mapping:
                    choice_mapping[base_value] = prepare_value_count(base_value, 0)

            update_obj = gr.update(
                choices=list(choice_mapping.values()),
                value=[choice_mapping[base_value] for base_value in prior_bases],
            )
        elif isinstance(card_values, dict):
            # range-slider, maybe other options in the future?
            _require_range_card(card_name, card_values)
            # for range slider, nothing to actually do!
            update_obj = gr.update()
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    case_count = payload["data"]["pagination"]["total"]
    return card_updates + [gr.update(value=f"{case_count} Cases")]


def update_active_selections(*selected_filters_per_card):
    """Summarise all non-default card selections as `CARD NAME: value` entries.

    Args:
        *selected_filters_per_card: Current value of every card, in
            CARD_NAMES order.

    Returns:
        A `gr.update` setting both the choices and the value of the active
        selections group to the summary entries.

    Raises:
        ValueError: if a card has an unrecognised config.
    """
    choices = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        # use the default values to determine card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        label = card_name.upper()
        if isinstance(default_values, list):
            # checkbox
            choices.extend(
                f"{label}: {get_base_value(selected)}" for selected in selected_filters
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            _require_range_card(card_name, default_values)
            lo, hi = selected_filters
            # only add range filter if not default
            if lo != default_values["min"] or hi != default_values["max"]:
                choices.append(f"{label}: {int(lo)}-{int(hi)}")
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(choices=choices, value=choices)


def update_cards_from_active(current_selections, *selected_filters_per_card):
    """Propagate deselections in the active selections group back to the cards.

    Args:
        current_selections: Entries still checked in the active selections
            group, each formatted as `CARD NAME: value`.
        *selected_filters_per_card: Current value of every card, in
            CARD_NAMES order.

    Returns:
        list: An update for the active selections group followed by one
        `gr.update` per card.

    Raises:
        ValueError: if a card has an unrecognised config.
    """
    # active selector uses a flattened list so re-agg values under card groups
    grouped_selections = defaultdict(set)
    for entry in current_selections:
        group, _, selected = entry.partition(": ")
        grouped_selections[group].add(selected)

    card_updates = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        # use the default values to determine card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        label = card_name.upper()
        if isinstance(default_values, list):
            # checkbox: keep only values still present in the active group
            # (`.get` avoids inserting empty groups into the defaultdict)
            still_active = grouped_selections.get(label, ())
            update_obj = gr.update(
                value=[
                    selected
                    for selected in selected_filters
                    if get_base_value(selected) in still_active
                ]
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            _require_range_card(card_name, default_values)
            # the active selector cannot change range values
            # so if present as an active selection, no action is needed
            # otherwise, reset entire range selector
            if label in grouped_selections:
                update_obj = gr.update()
            else:
                update_obj = gr.update(
                    value=(
                        default_values["min"],
                        default_values["max"],
                    )
                )
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    # also remove unselected value as possible choice
    active_selection_update = gr.update(choices=current_selections)
    return [active_selection_update] + card_updates


def prepare_value_count(value, count):
    """Format a checkbox label as `value [count]`."""
    return f"{value} [{count}]"


def get_base_value(value):
    """Strip a trailing ` [count]` suffix (from the last ` [`) if present."""
    if " [" in value:
        value = value[: value.rfind(" [")]
    return value


# Tab selection helper
def set_active_tab(selected_tab):
    """Return visibility updates for tab containers, then variant updates for tab buttons."""
    visibles = [gr.update(visible=(tab == selected_tab)) for tab in TAB_NAMES]
    variants = [
        gr.update(variant="primary" if tab == selected_tab else "secondary")
        for tab in TAB_NAMES
    ]
    return visibles + variants


# Client-side handler for the export button: fetches case ids matching the
# current filter and offers them as a downloadable text file.
# NOTE(review): the loading-indicator markup assigned to the button was garbled
# in the reviewed copy of this file (an unterminated string ending in `<\div>`),
# so a plain text label is used while the request is in flight. Restore the
# original spinner markup here if style.css defines one.
DOWNLOAD_CASES_JS = f"""
function download_cases(filter_str) {{
    const params = new URLSearchParams();
    params.set('fields', 'case_id');
    params.set('format', 'JSON');
    params.set('size', 100000);
    params.set('filters', filter_str);
    const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();

    const button = document.getElementById("download-btn");
    button.textContent = 'Exporting...';
    button.disabled = true;

    fetch(url).then(resp => {{
        if (!resp.ok) throw new Error("Failed to fetch TSV.");
        return resp.json();
    }})
    .then(data => {{
        const ids = data.data.hits.map(item => item.id);
        const text = ids.join("\\n");
        const blob = new Blob([text], {{type: "text/plain"}});
        return blob;
    }})
    .then(blob => {{
        const blobUrl = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = blobUrl;
        a.download = "gdc_cohort_case_ids.tsv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(blobUrl);
    }})
    .catch(error => {{
        alert("Download failed: " + error.message);
    }})
    .finally(() => {{
        // Always restore the button, including after a failed download.
        button.innerHTML = 'Export to GDC';
        button.disabled = false;
    }});
}}
"""

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown("# GDC Cohort Copilot")

    with gr.Row(equal_height=True):
        with gr.Column(scale=7):
            text_input = gr.Textbox(
                label="Describe the cohort you're looking for:",
                info=(
                    "Only provide the cohort characteristics. "
                    "Do not include extraneous text. "
                    "For example, write 'patients with X' "
                    "instead of 'I would like patients with X':"
                ),
                submit_btn="Generate Cohort",
                elem_id="description-input",
                placeholder="Enter a cohort description to begin...",
            )
        with gr.Column(scale=1, min_width=150):
            case_counter = gr.Text(
                show_label=False,
                interactive=False,
                container=False,
                elem_id="case-counter",
                min_width=150,
            )
            case_download = gr.Button(
                value="Export to GDC",
                min_width=150,
                elem_id="download-btn",
            )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Examples(
                examples=EXAMPLE_INPUTS,
                inputs=text_input,
            )
        with gr.Column(scale=4):
            json_output = gr.Code(
                label="Cohort Filter JSON",
                value=json.dumps({"op": "and", "content": []}, indent=4),
                language="json",
                interactive=False,
                show_label=True,
                container=True,
                elem_id="json-output",
            )

    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Currently Selected Filters")
        with gr.Column(scale=4):
            active_selections = gr.CheckboxGroup(
                choices=[],
                show_label=False,
                interactive=True,
                elem_id="active-selections",
            )

    with gr.Row():
        gr.Markdown(
            "The generated cohort filter will autopopulate into the filter cards below. "
            "**GDC Cohort Copilot can make mistakes!** "
            "Refine your search using the interactive checkboxes. "
            "Note that many other options can be found by selecting the different tabs on the left."
        )

    with gr.Row():
        # Tab selectors
        tab_buttons = []
        with gr.Column(scale=1, min_width=250):
            for name in TAB_NAMES:
                tab_button = gr.Button(
                    value=name,
                    variant="primary" if name == TAB_NAMES[0] else "secondary",
                )
                tab_buttons.append(tab_button)

        # Filter cards
        tab_containers = []
        filter_cards = []
        for tab in CONFIG["tabs"]:
            visible = tab["name"] == TAB_NAMES[0]  # first tab is shown by default
            with gr.Column(scale=4, visible=visible) as tab_container:
                tab_containers.append(tab_container)
                with gr.Row(elem_classes=["card-group"]):
                    for card in tab["cards"]:
                        if isinstance(card["values"], list):
                            filter_card = gr.CheckboxGroup(
                                choices=[],
                                label=card["name"],
                                interactive=True,
                                elem_classes=["filter-card"],
                            )
                        else:
                            # values is a dictionary and defines some meta options
                            metaopts = card["values"]
                            if (
                                not isinstance(metaopts, dict)
                                or metaopts.get("type") != "range"
                                or not all(k in metaopts for k in ("min", "max"))
                            ):
                                raise ValueError(
                                    f"Unknown meta options for {card['name']}"
                                )
                            info = "Inclusive range"
                            if "unit" in metaopts:
                                info += f", units in {metaopts['unit']}"
                            filter_card = RangeSlider(
                                label=card["name"],
                                info=info,
                                minimum=metaopts["min"],
                                maximum=metaopts["max"],
                                step=1,  # assume integer
                                elem_classes=["filter-card", "filter-range"],
                            )
                        filter_cards.append(filter_card)

    # Assign tab buttons to toggle visibility
    for tab_button, name in zip(tab_buttons, TAB_NAMES):
        tab_button.click(
            fn=set_active_tab,
            inputs=gr.State(name),
            outputs=tab_containers + tab_buttons,
            api_name=False,
        )

    # Enable case download
    case_download.click(
        fn=None,  # apparently this isn't the same as not specifying it
        js=DOWNLOAD_CASES_JS,
        inputs=json_output,
        api_name=False,
    )

    # Load initial counts on startup
    demo.load(
        fn=update_cards_with_counts,
        inputs=[gr.State("")] + filter_cards,
        outputs=filter_cards + [case_counter],
        api_name=False,
    )

    # Update checkboxes on filter generation
    # Also update JSON based on checkboxes
    # - relying on checkbox update to do this fires multiple times
    # - also propagates new model selections after json is updated
    # Also this way it shows the model generated JSON
    text_input.submit(
        fn=process_query,
        inputs=text_input,
        outputs=filter_cards + [json_output],
        api_name=False,
    ).success(
        fn=update_active_selections,
        inputs=filter_cards,
        outputs=[active_selections],
        api_name=False,
    )

    # Update JSON based on cards
    # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
    # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
    for filter_card in filter_cards:
        # Range sliders are wired on `release`, checkbox groups on `input`;
        # both run the same handler chain.
        if isinstance(filter_card, RangeSlider):
            register_listener = filter_card.release
        else:
            register_listener = filter_card.input
        register_listener(
            fn=update_json_from_cards,
            inputs=filter_cards,
            outputs=json_output,
            api_name=False,
        ).success(
            fn=update_active_selections,
            inputs=filter_cards,
            outputs=[active_selections],
            api_name=False,
        )

    # Enable functionality of the active filter selectors
    active_selections.input(
        fn=update_cards_from_active,
        inputs=[active_selections] + filter_cards,
        outputs=[active_selections] + filter_cards,
        api_name=False,
    ).success(
        fn=update_json_from_cards,
        inputs=filter_cards,
        outputs=json_output,
        api_name=False,
    )

    # Update checkboxes after executing filter query
    json_output.change(
        fn=update_cards_with_counts,
        inputs=[json_output] + filter_cards,
        outputs=filter_cards + [case_counter],
        api_name=False,
    )

    # gr.api(generate_filter, api_name="generate_filter")

if __name__ == "__main__":
    demo.launch(ssr_mode=False)