import json import os from collections import defaultdict import gradio as gr import requests import spaces import torch import yaml from gradio_rangeslider import RangeSlider from guidance import json as gen_json from guidance.models import Transformers from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed from schema import GDCCohortSchema # isort: skip DEBUG = "DEBUG" in os.environ EXAMPLE_INPUTS = [ "bam files for TCGA-BRCA", "kidney or adrenal gland cancers with alcohol history", "tumor samples from male patients with acute myeloid lymphoma", ] GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases" MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M" TOKENIZER_NAME = MODEL_NAME AUTH_TOKEN = os.environ.get("HF_TOKEN", False) # HF_TOKEN must be set to use auth with open("config.yaml", "r") as f: CONFIG = yaml.safe_load(f) TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]] CARD_NAMES = [card["name"] for tab in CONFIG["tabs"] for card in tab["cards"]] CARD_FIELDS = [card["field"] for tab in CONFIG["tabs"] for card in tab["cards"]] CARD_2_FIELD = dict(list(zip(CARD_NAMES, CARD_FIELDS))) CARD_2_VALUES = { card["name"]: card["values"] for tab in CONFIG["tabs"] for card in tab["cards"] } FACETS_STR = ",".join( [ f.replace("cases.", "") for f, n in zip(CARD_FIELDS, CARD_NAMES) if not isinstance(CARD_2_VALUES[n], dict) # ^ skip range facets in bin counts ] ) if not DEBUG: tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN) # for some reason, pre-invoking tokenizer prevents endless generation when using guidance # opened ticket here: https://github.com/guidance-ai/guidance/issues/1322 tok("foobar") model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN) model = model.to("cuda" if torch.cuda.is_available() else "cpu") model = model.eval() DUMMY_FILTER = json.dumps( { "op": "and", "content": [ { "op": "in", "content": { "field": "cases.project.project_id", "value": ["TCGA-BRCA"], }, }, { "op": "in", "content": { "field": "cases.project.program.name", "value": ["TCGA"], }, }, { "op": "and", "content": [ { "op": ">=", "content": { "field": "cases.diagnoses.age_at_diagnosis", "value": 7305, }, }, { "op": "<=", "content": { "field": "cases.diagnoses.age_at_diagnosis", "value": 14610, }, }, ], }, ], }, indent=4, ) # Generate cohort filter JSON from free text @spaces.GPU(duration=15) def generate_filter(query: str) -> str: """ Converts a free text description of a cancer cohort into a GDC structured cohort filter. Args: query (str): The free text cohort description Returns: str: JSON structured GDC cohort filter """ if DEBUG: return DUMMY_FILTER set_seed(42) lm = Transformers( model=model, tokenizer=tok, # sampling_params=SamplingParams, ) lm += query lm += gen_json( name="cohort", schema=GDCCohortSchema, temperature=0, max_tokens=1024 ) cohort_filter = lm["cohort"] cohort_filter = json.dumps(json.loads(cohort_filter), indent=4) return cohort_filter # Transform query to filter to checkbox selections (and update json box) def process_query(query): # Generate filter cohort_filter_str = generate_filter(query) cohort_filter = json.loads(cohort_filter_str) # Pre-flatten nested ops for easier mapping in next step flattened_ops = [] for op in cohort_filter["content"]: # nested `and` can only be 1 deep based on schema if op["op"] == "and": flattened_ops.extend(op["content"]) else: flattened_ops.append(op) # Prepare and validate generated filters generated_field_2_values = dict() for op in flattened_ops: assert op["op"] in [ "in", "=", "<", ">", "<=", ">=", ], f"Unknown handling for op: {op}" content = op["content"] field, value = content["field"], content["value"] # comparators are ints so can convert to g/lte by add/sub 1 if op["op"] == "<": op["op"] = "<=" value -= 1 elif op["op"] == ">": op["op"] = ">=" value += 1 elif op["op"] == "=": # convert = to <=,>= ops so it can be filled into card flattened_ops.append( { "op": "<=", "content": content, } ) flattened_ops.append( { "op": ">=", "content": content, } ) continue if op["op"] != "in": # comp ops will duplicate name, disambiguate by appending comp field += "_" + op["op"] if field in generated_field_2_values: raise ValueError(f"{field} is ambiguously duplicated") generated_field_2_values[field] = value # Map filter selections to cards card_updates = [] for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS): # Need to update all cards so use all possible cards as ref default_values = CARD_2_VALUES[card_name] if isinstance(default_values, list): updated_values = [] updated_choices = default_values # reset value possible_values = set(updated_choices) if card_field in generated_field_2_values: # check ref against generated selected_values = generated_field_2_values.pop(card_field) unmatched_values = [] for selected_value in selected_values: if selected_value in possible_values: updated_values.append(selected_value) else: # model hallucination? unmatched_values.append(selected_value) if len(unmatched_values) > 0: generated_field_2_values[card_field] = unmatched_values update_obj = gr.update( choices=updated_choices, value=updated_values, # will override existing selections ) elif isinstance(default_values, dict): # range-slider, maybe other options in the future? assert ( default_values["type"] == "range" ), f"Expected range slider for card {card_name}" # Need to handle if model outputs flat range or nested range card_field_gte = card_field + "_>=" card_field_lte = card_field + "_<=" _min = default_values["min"] _max = default_values["max"] lo = generated_field_2_values.pop(card_field_gte, _min) hi = generated_field_2_values.pop(card_field_lte, _max) assert ( lo >= _min ), f"Generated lower bound ({lo}) less than minimum allowable value ({_min})" assert ( hi <= _max ), f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})" update_obj = gr.update(value=(lo, hi)) else: raise ValueError(f"Unknown values for card {card_name}") card_updates.append(update_obj) # generated_field_2_values will have remaining, unmatched values # edit: updated json schema with enumerated fields prevents unmatched fields print(f"Unmatched values in model generation: {generated_field_2_values}") return card_updates + [gr.update(value=cohort_filter_str)] # Update JSON based on checkbox selections def update_json_from_cards(*selected_filters_per_card): ops = [] for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card): # use the default values to determine card type (checkbox, range, etc) default_values = CARD_2_VALUES[card_name] if isinstance(default_values, list): # checkbox if len(selected_filters) > 0: base_values = [] for selected_value in selected_filters: base_value = get_base_value(selected_value) base_values.append(base_value) content = { "field": CARD_2_FIELD[card_name], "value": base_values, } op = { "op": "in", "content": content, } ops.append(op) elif isinstance(default_values, dict): # range-slider, maybe other options in the future? assert ( default_values["type"] == "range" ), f"Expected range slider for card {card_name}" lo, hi = selected_filters subops = [] for val, limit, comp in [ (lo, default_values["min"], ">="), (hi, default_values["max"], "<="), ]: # only add range filter if not default if val == limit: continue subop = { "op": comp, "content": { "field": CARD_2_FIELD[card_name], "value": int(val), }, } subops.append(subop) if len(subops) > 0: ops.append({"op": "and", "content": subops}) else: raise ValueError(f"Unknown values for card {card_name}") cohort_filter = { "op": "and", "content": ops, } filter_json = json.dumps(cohort_filter, indent=4) return gr.update(value=filter_json) # Execute GDC API query and prepare checkbox + case counter updates # Preserve prior selections def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card): card_2_selections = dict(list(zip(CARD_NAMES, selected_filters_per_card))) # Execute GDC API query params = { "facets": FACETS_STR, "pretty": "false", "format": "JSON", "size": 0, } if cohort_filter: # patch for range selectors which use nested `and` # seems `facets` and nested `and` don't play well together # so flatten direct nested `and` for query execution only # this is equivalent since our top-level is always `and` # keeping nested `and` for presentation and model generations though temp = json.loads(cohort_filter) ops = temp["content"] new_ops = [] for op in ops: # assumes no deeper than single level nesting if op["op"] == "and": for subop in op["content"]: new_ops.append(subop) else: new_ops.append(op) temp["content"] = new_ops cohort_filter = json.dumps(temp) params["filters"] = cohort_filter response = requests.get(GDC_CASES_API_ENDPOINT, params=params) if not response.ok: raise Exception(f"API error: {response.status_code}\n{response.json()}") temp = response.json() # Update checkboxes with bin counts card_updates = [] all_counts = temp["data"]["aggregations"] for card_name in CARD_NAMES: card_field = CARD_2_FIELD[card_name] card_field = card_field.replace("cases.", "") card_values = CARD_2_VALUES[card_name] if isinstance(card_values, list): # value checkboxes choice_mapping = {} updated_choices = [] card_counts = { x["key"]: x["doc_count"] for x in all_counts[card_field]["buckets"] } for value_name in card_values: if value_name in card_counts: value_str = prepare_value_count( value_name, card_counts[value_name], ) # track possible choices to use as values choice_mapping[value_name] = value_str updated_choices.append(value_str) # Align prior selections with new choices updated_values = [] for selected_value in card_2_selections[card_name]: base_value = get_base_value(selected_value) if base_value not in choice_mapping: # Re-add choices which now presumably have 0 counts choice_mapping[base_value] = prepare_value_count(base_value, 0) updated_values.append(choice_mapping[base_value]) update_obj = gr.update( choices=updated_choices, value=updated_values, ) elif isinstance(card_values, dict): # range-slider, maybe other options in the future? assert ( card_values["type"] == "range" ), f"Expected range slider for card {card_name}" # for range slider, nothing to actually do! update_obj = gr.update() else: raise ValueError(f"Unknown values for card {card_name}") card_updates.append(update_obj) case_count = temp["data"]["pagination"]["total"] return card_updates + [gr.update(value=f"{case_count} Cases")] def update_active_selections(*selected_filters_per_card): choices = [] for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card): # use the default values to determine card type (checkbox, range, etc) default_values = CARD_2_VALUES[card_name] if isinstance(default_values, list): # checkbox for selected_value in selected_filters: base_value = get_base_value(selected_value) choices.append(f"{card_name.upper()}: {base_value}") elif isinstance(default_values, dict): # range-slider, maybe other options in the future? assert ( default_values["type"] == "range" ), f"Expected range slider for card {card_name}" lo, hi = selected_filters if lo != default_values["min"] or hi != default_values["max"]: # only add range filter if not default lo, hi = int(lo), int(hi) choices.append(f"{card_name.upper()}: {lo}-{hi}") else: raise ValueError(f"Unknown values for card {card_name}") return gr.update(choices=choices, value=choices) def update_cards_from_active(current_selections, *selected_filters_per_card): # active selector uses a flattened list so re-agg values under card groups grouped_selections = defaultdict(set) for k_v in current_selections: idx = k_v.find(": ") k, v = k_v[:idx], k_v[idx + 2 :] grouped_selections[k].add(v) card_updates = [] for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card): # use the default values to determine card type (checkbox, range, etc) default_values = CARD_2_VALUES[card_name] if isinstance(default_values, list): # checkbox updated_values = [] for selected_value in selected_filters: base_value = get_base_value(selected_value) if base_value in grouped_selections[card_name.upper()]: updated_values.append(selected_value) update_obj = gr.update(value=updated_values) elif isinstance(default_values, dict): # range-slider, maybe other options in the future? assert ( default_values["type"] == "range" ), f"Expected range slider for card {card_name}" # the active selector cannot change range values # so if present as an active selection, no action is needed # otherwise, reset entire range selector if card_name.upper() in grouped_selections: update_obj = gr.update() else: update_obj = gr.update( value=( default_values["min"], default_values["max"], ) ) else: raise ValueError(f"Unknown values for card {card_name}") card_updates.append(update_obj) # also remove unselected value as possible choice active_selection_update = gr.update(choices=current_selections) return [active_selection_update] + card_updates def prepare_value_count(value, count): return f"{value} [{count}]" def get_base_value(value): if " [" in value: value = value[: value.rfind(" [")] return value # Tab selection helper def set_active_tab(selected_tab): visibles = [gr.update(visible=(tab == selected_tab)) for tab in TAB_NAMES] elem_classes = [ gr.update(variant="primary" if tab == selected_tab else "secondary") for tab in TAB_NAMES ] return visibles + elem_classes DOWNLOAD_CASES_JS = f""" function download_cases(filter_str) {{ const params = new URLSearchParams(); params.set('fields', 'case_id'); params.set('format', 'JSON'); params.set('size', 100000); params.set('filters', filter_str); const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString(); const button = document.getElementById("download-btn"); button.innerHTML = '