# GDC-Cohort-Copilot / app.py.old
# songs1's picture
# wip reconstruction
# 1dbb331
import json
import os
from collections import defaultdict
import gradio as gr
import requests
import spaces
import torch
import yaml
from gradio_rangeslider import RangeSlider
from guidance import json as gen_json
from guidance.models import Transformers
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
from schema import GDCCohortSchema # isort: skip
# --- Runtime switches and fixed resources ---------------------------------
# DEBUG is enabled by the mere presence of the environment variable.
DEBUG = "DEBUG" in os.environ

EXAMPLE_INPUTS = [
    "bam files for TCGA-BRCA",
    "kidney or adrenal gland cancers with alcohol history",
    "tumor samples from male patients with acute myeloid lymphoma",
]

GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"

MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
TOKENIZER_NAME = MODEL_NAME
AUTH_TOKEN = os.environ.get("HF_TOKEN", False)  # HF_TOKEN must be set to use auth

# --- Card configuration ---------------------------------------------------
with open("config.yaml", "r") as f:
    CONFIG = yaml.safe_load(f)

# Flatten the tab -> card hierarchy once; every lookup below derives from it.
_ALL_CARDS = [card for tab in CONFIG["tabs"] for card in tab["cards"]]

TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]
CARD_NAMES = [card["name"] for card in _ALL_CARDS]
CARD_FIELDS = [card["field"] for card in _ALL_CARDS]
CARD_2_FIELD = dict(zip(CARD_NAMES, CARD_FIELDS))
CARD_2_VALUES = {card["name"]: card["values"] for card in _ALL_CARDS}

# Comma-separated facet list for bin counts; cards whose values are a dict
# (range facets) are skipped.
FACETS_STR = ",".join(
    field.replace("cases.", "")
    for field, name in zip(CARD_FIELDS, CARD_NAMES)
    if not isinstance(CARD_2_VALUES[name], dict)
)

# --- Model loading (skipped entirely in DEBUG mode) -------------------------
if not DEBUG:
    tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
    # for some reason, pre-invoking tokenizer prevents endless generation when using guidance
    # opened ticket here: https://github.com/guidance-ai/guidance/issues/1322
    tok("foobar")
    model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
def _dummy_clause(op, field, value):
    """Build one ``{"op", "content": {"field", "value"}}`` filter clause."""
    return {"op": op, "content": {"field": field, "value": value}}


# Canned cohort filter (pretty-printed JSON string): two `in` clauses plus a
# nested `and` holding an age-at-diagnosis range.
DUMMY_FILTER = json.dumps(
    {
        "op": "and",
        "content": [
            _dummy_clause("in", "cases.project.project_id", ["TCGA-BRCA"]),
            _dummy_clause("in", "cases.project.program.name", ["TCGA"]),
            {
                "op": "and",
                "content": [
                    _dummy_clause(">=", "cases.diagnoses.age_at_diagnosis", 7305),
                    _dummy_clause("<=", "cases.diagnoses.age_at_diagnosis", 14610),
                ],
            },
        ],
    },
    indent=4,
)
# Generate cohort filter JSON from free text
@spaces.GPU(duration=15)
def generate_filter(query: str) -> str:
    """
    Turn a free text cancer cohort description into a GDC structured cohort filter.

    In DEBUG mode no model is invoked and the canned DUMMY_FILTER is returned.
    Otherwise the seed is fixed, the query is fed to the language model and the
    continuation is constrained to GDCCohortSchema.

    Args:
        query (str): The free text cohort description

    Returns:
        str: JSON structured GDC cohort filter, indented by 4 spaces
    """
    if not DEBUG:
        set_seed(42)
        guided = Transformers(model=model, tokenizer=tok)
        guided += query
        guided += gen_json(
            name="cohort",
            schema=GDCCohortSchema,
            temperature=0,
            max_tokens=1024,
        )
        # Round-trip through the JSON parser to normalise the formatting
        parsed = json.loads(guided["cohort"])
        return json.dumps(parsed, indent=4)
    return DUMMY_FILTER
def _collect_generated_values(cohort_filter):
    """Flatten a generated cohort filter into a ``{key: value}`` mapping.

    ``in`` clauses are keyed by their field name. Comparison clauses are keyed
    by ``<field>_<=`` / ``<field>_>=`` so a lower and an upper bound on the same
    field do not collide. Strict comparisons are converted to inclusive ones by
    adding/subtracting 1 (comparators are ints), and ``=`` is recorded as both
    an inclusive lower and upper bound.

    Raises:
        ValueError: on an unsupported op or when the same key appears twice.
    """
    # Pre-flatten nested ops; nested `and` can only be 1 deep based on schema
    flattened_ops = []
    for op in cohort_filter["content"]:
        if op["op"] == "and":
            flattened_ops.extend(op["content"])
        else:
            flattened_ops.append(op)

    field_2_values = {}

    def record(key, value):
        if key in field_2_values:
            raise ValueError(f"{key} is ambiguously duplicated")
        field_2_values[key] = value

    for op in flattened_ops:
        op_name = op["op"]
        if op_name not in ("in", "=", "<", ">", "<=", ">="):
            # Explicit raise instead of `assert`: asserts vanish under `python -O`
            raise ValueError(f"Unknown handling for op: {op}")
        content = op["content"]
        field, value = content["field"], content["value"]
        if op_name == "in":
            record(field, value)
        elif op_name == "=":
            # equality pins both ends of the range card
            record(field + "_<=", value)
            record(field + "_>=", value)
        elif op_name == "<":
            record(field + "_<=", value - 1)
        elif op_name == ">":
            record(field + "_>=", value + 1)
        else:
            # already inclusive: "<=" or ">="
            record(field + "_" + op_name, value)
    return field_2_values


# Transform query to filter to checkbox selections (and update json box)
def process_query(query):
    """Generate a cohort filter for ``query`` and map it onto the filter cards.

    Args:
        query: Free text cohort description, passed to ``generate_filter``.

    Returns:
        list: One ``gr.update`` per card (in ``CARD_NAMES`` order) followed by
        a ``gr.update`` carrying the generated filter JSON string.

    Raises:
        ValueError: if the generated filter contains an unsupported op, a
            duplicated field, a range bound outside the card's min/max, or if
            a card's configured values are neither a list nor a dict.
    """
    # Generate filter
    cohort_filter_str = generate_filter(query)
    cohort_filter = json.loads(cohort_filter_str)

    # Prepare and validate generated filters
    generated_field_2_values = _collect_generated_values(cohort_filter)

    # Map filter selections to cards
    card_updates = []
    for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
        # Need to update all cards so use all possible cards as ref
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            updated_values = []
            possible_values = set(default_values)
            if card_field in generated_field_2_values:
                # check ref against generated
                selected_values = generated_field_2_values.pop(card_field)
                unmatched_values = []
                for selected_value in selected_values:
                    if selected_value in possible_values:
                        updated_values.append(selected_value)
                    else:
                        # model hallucination?
                        unmatched_values.append(selected_value)
                if unmatched_values:
                    # keep them around so they show up in the log line below
                    generated_field_2_values[card_field] = unmatched_values
            update_obj = gr.update(
                choices=default_values,  # reset choices to the configured values
                value=updated_values,  # will override existing selections
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            assert (
                default_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            _min = default_values["min"]
            _max = default_values["max"]
            # A missing bound falls back to the card's own limit
            lo = generated_field_2_values.pop(card_field + "_>=", _min)
            hi = generated_field_2_values.pop(card_field + "_<=", _max)
            if lo < _min:
                raise ValueError(
                    f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
                )
            if hi > _max:
                raise ValueError(
                    f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
                )
            update_obj = gr.update(value=(lo, hi))
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    # generated_field_2_values will have remaining, unmatched values
    # edit: updated json schema with enumerated fields prevents unmatched fields
    print(f"Unmatched values in model generation: {generated_field_2_values}")
    return card_updates + [gr.update(value=cohort_filter_str)]
# Update JSON based on checkbox selections
def update_json_from_cards(*selected_filters_per_card):
    """Rebuild the cohort filter JSON from the current card selections.

    Args:
        *selected_filters_per_card: Current value of every card, in
            ``CARD_NAMES`` order (a list of labels for checkbox cards, a
            ``(lo, hi)`` pair for range cards).

    Returns:
        A ``gr.update`` whose value is the filter JSON, indented by 4 spaces.

    Raises:
        ValueError: if a card's configured values are neither list nor dict.
    """
    clauses = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        # the configured values tell us the card type (checkbox, range, etc)
        configured = CARD_2_VALUES[card_name]

        if isinstance(configured, list):
            # checkbox card: nothing selected means no clause
            if len(selection) == 0:
                continue
            clauses.append(
                {
                    "op": "in",
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        # strip the " [count]" decoration from each label
                        "value": [get_base_value(label) for label in selection],
                    },
                }
            )
        elif isinstance(configured, dict):
            # range-slider, maybe other options in the future?
            assert (
                configured["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selection
            bounds = (
                (lo, configured["min"], ">="),
                (hi, configured["max"], "<="),
            )
            # a bound sitting on its default limit adds no constraint
            range_clauses = [
                {
                    "op": comparator,
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        "value": int(bound),
                    },
                }
                for bound, limit, comparator in bounds
                if bound != limit
            ]
            if len(range_clauses) > 0:
                clauses.append({"op": "and", "content": range_clauses})
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(value=json.dumps({"op": "and", "content": clauses}, indent=4))
# Execute GDC API query and prepare checkbox + case counter updates
# Preserve prior selections
def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
    """Query the GDC cases endpoint and refresh card choices with bin counts.

    Args:
        cohort_filter: Cohort filter JSON string; an empty string queries
            without a ``filters`` parameter.
        *selected_filters_per_card: Current value of every card, in
            ``CARD_NAMES`` order. Prior selections are preserved.

    Returns:
        list: One ``gr.update`` per card followed by a ``gr.update`` for the
        case counter text (``"<total> Cases"``).

    Raises:
        Exception: if the API responds with a non-OK status.
        ValueError: if a card's configured values are neither list nor dict.
    """
    card_2_selections = dict(zip(CARD_NAMES, selected_filters_per_card))

    # Execute GDC API query
    params = {
        "facets": FACETS_STR,
        "pretty": "false",
        "format": "JSON",
        "size": 0,
    }
    if cohort_filter:
        # patch for range selectors which use nested `and`
        # seems `facets` and nested `and` don't play well together
        # so flatten direct nested `and` for query execution only
        # this is equivalent since our top-level is always `and`
        # keeping nested `and` for presentation and model generations though
        parsed_filter = json.loads(cohort_filter)
        flat_ops = []
        for op in parsed_filter["content"]:
            # assumes no deeper than single level nesting
            if op["op"] == "and":
                flat_ops.extend(op["content"])
            else:
                flat_ops.append(op)
        parsed_filter["content"] = flat_ops
        params["filters"] = json.dumps(parsed_filter)

    # Bound the wait so a stalled connection cannot hang the UI callback forever
    response = requests.get(GDC_CASES_API_ENDPOINT, params=params, timeout=60)
    if not response.ok:
        # Error bodies are not guaranteed to be JSON; fall back to raw text so
        # the real HTTP failure is reported instead of a decode error
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise Exception(f"API error: {response.status_code}\n{detail}")
    payload = response.json()

    # Update checkboxes with bin counts
    card_updates = []
    all_counts = payload["data"]["aggregations"]
    for card_name in CARD_NAMES:
        card_field = CARD_2_FIELD[card_name].replace("cases.", "")
        card_values = CARD_2_VALUES[card_name]
        if isinstance(card_values, list):
            # value checkboxes
            card_counts = {
                bucket["key"]: bucket["doc_count"]
                for bucket in all_counts[card_field]["buckets"]
            }
            choice_mapping = {}  # base value -> "value [count]" label
            updated_choices = []
            for value_name in card_values:
                if value_name in card_counts:
                    label = prepare_value_count(value_name, card_counts[value_name])
                    choice_mapping[value_name] = label
                    updated_choices.append(label)

            # Align prior selections with new choices
            updated_values = []
            for selected_value in card_2_selections[card_name] or []:
                base_value = get_base_value(selected_value)
                if base_value not in choice_mapping:
                    # Re-add choices which now presumably have 0 counts.
                    # The label must also be offered as a choice, otherwise
                    # the selected value is not among the card's choices.
                    zero_label = prepare_value_count(base_value, 0)
                    choice_mapping[base_value] = zero_label
                    updated_choices.append(zero_label)
                updated_values.append(choice_mapping[base_value])
            update_obj = gr.update(
                choices=updated_choices,
                value=updated_values,
            )
        elif isinstance(card_values, dict):
            # range-slider, maybe other options in the future?
            assert (
                card_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            # for range slider, nothing to actually do!
            update_obj = gr.update()
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    case_count = payload["data"]["pagination"]["total"]
    return card_updates + [gr.update(value=f"{case_count} Cases")]
def update_active_selections(*selected_filters_per_card):
    """Build the flat "currently selected filters" list from all cards.

    Each checkbox selection becomes ``"<CARD NAME>: <value>"``; a range card
    contributes ``"<CARD NAME>: <lo>-<hi>"`` only when it differs from its
    configured min/max.

    Args:
        *selected_filters_per_card: Current value of every card, in
            ``CARD_NAMES`` order.

    Returns:
        A ``gr.update`` that sets the entries as both choices and value.

    Raises:
        ValueError: if a card's configured values are neither list nor dict.
    """
    entries = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        # the configured values tell us the card type (checkbox, range, etc)
        configured = CARD_2_VALUES[card_name]

        if isinstance(configured, list):
            # checkbox card: one entry per ticked value, without its count
            for label in selection:
                entries.append(f"{card_name.upper()}: {get_base_value(label)}")
        elif isinstance(configured, dict):
            # range-slider, maybe other options in the future?
            assert (
                configured["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selection
            untouched = lo == configured["min"] and hi == configured["max"]
            if not untouched:
                # only list the range when it is not the default
                lo, hi = int(lo), int(hi)
                entries.append(f"{card_name.upper()}: {lo}-{hi}")
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(choices=entries, value=entries)
def update_cards_from_active(current_selections, *selected_filters_per_card):
    """Propagate de-selections in the active-filter list back to the cards.

    Args:
        current_selections: Entries still ticked in the active selector, each
            formatted as ``"<CARD NAME>: <value>"``.
        *selected_filters_per_card: Current value of every card, in
            ``CARD_NAMES`` order.

    Returns:
        list: A ``gr.update`` for the active selector (choices narrowed to the
        remaining entries) followed by one ``gr.update`` per card.

    Raises:
        ValueError: if a card's configured values are neither list nor dict.
    """
    # the active selector is a flat list, so regroup its values per card name
    still_active = defaultdict(set)
    for entry in current_selections:
        cut = entry.find(": ")
        still_active[entry[:cut]].add(entry[cut + 2 :])

    per_card_updates = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        # the configured values tell us the card type (checkbox, range, etc)
        configured = CARD_2_VALUES[card_name]

        if isinstance(configured, list):
            # checkbox card: keep only labels whose base value is still active
            kept = [
                label
                for label in selection
                if get_base_value(label) in still_active[card_name.upper()]
            ]
            card_update = gr.update(value=kept)
        elif isinstance(configured, dict):
            # range-slider, maybe other options in the future?
            assert (
                configured["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            # the active selector cannot change range values: leave the slider
            # alone while its entry is present, otherwise reset the whole range
            card_update = (
                gr.update()
                if card_name.upper() in still_active
                else gr.update(value=(configured["min"], configured["max"]))
            )
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        per_card_updates.append(card_update)

    # also remove unselected value as possible choice
    return [gr.update(choices=current_selections), *per_card_updates]
def prepare_value_count(value, count):
    """Return the checkbox label for ``value`` decorated with its count, e.g. ``"male [5]"``."""
    return "{} [{}]".format(value, count)
def get_base_value(value):
    """Strip a trailing `` [<count>]`` decoration from a checkbox label.

    Only a suffix of the exact form produced by ``prepare_value_count`` with
    an integer count (a space, ``[``, ASCII digits, ``]`` at the very end) is
    removed. Raw values that merely contain `` [`` are returned unchanged
    instead of being truncated.

    Args:
        value (str): A label, with or without a count suffix.

    Returns:
        str: The label without its count suffix.
    """
    head, sep, tail = value.rpartition(" [")
    if sep and tail.endswith("]"):
        digits = tail[:-1]
        if digits.isascii() and digits.isdigit():
            return head
    return value
# Tab selection helper
def set_active_tab(selected_tab):
    """Show the selected tab's container and highlight its button.

    Args:
        selected_tab: Name of the tab to activate.

    Returns:
        list: Visibility updates for every tab (in ``TAB_NAMES`` order)
        followed by button-variant updates for every tab.
    """
    visibility_updates = []
    variant_updates = []
    for tab in TAB_NAMES:
        is_active = tab == selected_tab
        visibility_updates.append(gr.update(visible=is_active))
        variant_updates.append(
            gr.update(variant="primary" if is_active else "secondary")
        )
    return visibility_updates + variant_updates
# Client-side JavaScript handler: fetches the case ids matching the given filter
# string from the GDC cases endpoint and offers them as a newline-separated
# download. Braces are doubled because this is an f-string; only
# GDC_CASES_API_ENDPOINT is interpolated.
# The download button shows a spinner while the request is in flight and is
# restored in `.finally`, so it is re-enabled after a failure as well.
DOWNLOAD_CASES_JS = f"""
function download_cases(filter_str) {{
    const params = new URLSearchParams();
    params.set('fields', 'case_id');
    params.set('format', 'JSON');
    params.set('size', 100000);
    params.set('filters', filter_str);
    const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();
    const button = document.getElementById("download-btn");
    button.innerHTML = '<div class="spinner"></div>';
    button.disabled = true;
    fetch(url).then(resp => {{
        if (!resp.ok) throw new Error("Failed to fetch TSV.");
        return resp.json();
    }})
    .then(data => {{
        const ids = data.data.hits.map(item => item.id);
        const text = ids.join("\\n");
        return new Blob([text], {{type: "text/plain"}});
    }})
    .then(blob => {{
        const objectUrl = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = objectUrl;
        a.download = "gdc_cohort_case_ids.tsv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(objectUrl);
    }})
    .catch(error => {{
        alert("Download failed: " + error.message);
    }})
    .finally(() => {{
        button.innerHTML = 'Export to GDC';
        button.disabled = false;
    }});
}}
"""
# Build the Gradio UI: layout first, then the event wiring between the query
# box, the filter cards, the active-selection list and the JSON view.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown("# GDC Cohort Copilot")
    # Row 1: free text query box plus case counter / export button
    with gr.Row(equal_height=True):
        with gr.Column(scale=7):
            text_input = gr.Textbox(
                label="Describe the cohort you're looking for:",
                info=(
                    "Only provide the cohort characteristics. "
                    "Do not include extraneous text. "
                    "For example, write 'patients with X' "
                    "instead of 'I would like patients with X':"
                ),
                submit_btn="Generate Cohort",
                elem_id="description-input",
                placeholder="Enter a cohort description to begin...",
            )
        with gr.Column(scale=1, min_width=150):
            case_counter = gr.Text(
                show_label=False,
                interactive=False,
                container=False,
                elem_id="case-counter",
                min_width=150,
            )
            # elem_id is looked up by DOWNLOAD_CASES_JS to toggle the spinner
            case_download = gr.Button(
                value="Export to GDC",
                min_width=150,
                elem_id="download-btn",
            )
    # Row 2: example queries and the read-only filter JSON
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Examples(
                examples=EXAMPLE_INPUTS,
                inputs=text_input,
            )
        with gr.Column(scale=4):
            # starts out as an empty top-level `and` filter
            json_output = gr.Code(
                label="Cohort Filter JSON",
                value=json.dumps({"op": "and", "content": []}, indent=4),
                language="json",
                interactive=False,
                show_label=True,
                container=True,
                elem_id="json-output",
            )
    # Row 3: flat list of the currently selected filters
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Currently Selected Filters")
        with gr.Column(scale=4):
            active_selections = gr.CheckboxGroup(
                choices=[],
                show_label=False,
                interactive=True,
                elem_id="active-selections",
            )
    with gr.Row():
        gr.Markdown(
            "The generated cohort filter will autopopulate into the filter cards below. "
            "**GDC Cohort Copilot can make mistakes!** "
            "Refine your search using the interactive checkboxes. "
            "Note that many other options can be found by selecting the different tabs on the left."
        )
    # Row 4: tab buttons on the left, one column of filter cards per tab
    with gr.Row():
        # Tab selectors
        tab_buttons = []
        with gr.Column(scale=1, min_width=250):
            for name in TAB_NAMES:
                tab_button = gr.Button(
                    value=name,
                    variant="primary" if name == TAB_NAMES[0] else "secondary",
                )
                tab_buttons.append(tab_button)
        # Filter cards
        # filter_cards is built in the same tab/card order as CARD_NAMES, which
        # the callbacks rely on when zipping card names with card values
        tab_containers = []
        filter_cards = []
        for tab in CONFIG["tabs"]:
            visible = tab["name"] == TAB_NAMES[0]  # default first card
            with gr.Column(scale=4, visible=visible) as tab_container:
                tab_containers.append(tab_container)
                with gr.Row(elem_classes=["card-group"]):
                    for card in tab["cards"]:
                        if isinstance(card["values"], list):
                            # choices start empty; they are filled by
                            # update_cards_with_counts on page load
                            filter_card = gr.CheckboxGroup(
                                choices=[],
                                label=card["name"],
                                interactive=True,
                                elem_classes=["filter-card"],
                            )
                        else:
                            # values is a dictionary and defines some meta options
                            metaopts = card["values"]
                            assert (
                                "type" in metaopts
                                and metaopts["type"] == "range"
                                and all(
                                    k in metaopts
                                    for k in [
                                        "min",
                                        "max",
                                    ]
                                )
                            ), f"Unknown meta options for {card['name']}"
                            info = "Inclusive range"
                            if "unit" in metaopts:
                                info += f", units in {metaopts['unit']}"
                            filter_card = RangeSlider(
                                label=card["name"],
                                info=info,
                                minimum=metaopts["min"],
                                maximum=metaopts["max"],
                                step=1,  # assume integer
                                elem_classes=["filter-card", "filter-range"],
                            )
                        filter_cards.append(filter_card)
    # Assign tab buttons to toggle visibility
    for tab_button, name in zip(tab_buttons, TAB_NAMES):
        tab_button.click(
            fn=set_active_tab,
            inputs=gr.State(name),
            outputs=tab_containers + tab_buttons,
            api_name=False,
        )
    # Enable case download
    # Runs purely in the browser: the current filter JSON is handed to the JS function
    case_download.click(
        fn=None,  # apparently this isn't the same as not specifying it
        js=DOWNLOAD_CASES_JS,
        inputs=json_output,
        api_name=False,
    )
    # Load initial counts on startup
    # An empty filter string makes update_cards_with_counts query without filters
    demo.load(
        fn=update_cards_with_counts,
        inputs=[gr.State("")] + filter_cards,
        outputs=filter_cards + [case_counter],
        api_name=False,
    )
    # Update checkboxes on filter generation
    # Also update JSON based on checkboxes
    # - relying on checkbox update to do this fires multiple times
    # - also propagates new model selections after json is updated
    # Also this way it shows the model generated JSON
    text_input.submit(
        fn=process_query,
        inputs=text_input,
        outputs=filter_cards + [json_output],
        api_name=False,
    ).success(
        fn=update_active_selections,
        inputs=filter_cards,
        outputs=[active_selections],
        api_name=False,
    )
    # Update JSON based on cards
    # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
    # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
    for filter_card in filter_cards:
        if isinstance(filter_card, RangeSlider):
            # range sliders are wired to `release`
            filter_card.release(
                fn=update_json_from_cards,
                inputs=filter_cards,
                outputs=json_output,
                api_name=False,
            ).success(
                fn=update_active_selections,
                inputs=filter_cards,
                outputs=[active_selections],
                api_name=False,
            )
        else:
            filter_card.input(
                fn=update_json_from_cards,
                inputs=filter_cards,
                outputs=json_output,
                api_name=False,
            ).success(
                fn=update_active_selections,
                inputs=filter_cards,
                outputs=[active_selections],
                api_name=False,
            )
    # Enable functionality of the active filter selectors
    # Unticking an entry first updates the cards, then rebuilds the JSON from them
    active_selections.input(
        fn=update_cards_from_active,
        inputs=[active_selections] + filter_cards,
        outputs=[active_selections] + filter_cards,
        api_name=False,
    ).success(
        fn=update_json_from_cards,
        inputs=filter_cards,
        outputs=json_output,
        api_name=False,
    )
    # Update checkboxes after executing filter query
    # Any change to the JSON box re-queries the API for counts and the case total
    json_output.change(
        fn=update_cards_with_counts,
        inputs=[json_output] + filter_cards,
        outputs=filter_cards + [case_counter],
        api_name=False,
    )
    # gr.api(generate_filter, api_name="generate_filter")
# Launch the app only when executed as a script; ssr_mode is explicitly disabled.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)