songs1 commited on
Commit
1dbb331
·
1 Parent(s): 5678887

wip reconstruction

Browse files
Files changed (3) hide show
  1. app.py +86 -139
  2. app.py.old +744 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -14,7 +14,6 @@ from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
14
 
15
  from schema import GDCCohortSchema # isort: skip
16
 
17
- DEBUG = "DEBUG" in os.environ
18
  EXAMPLE_INPUTS = [
19
  "bam files for TCGA-BRCA",
20
  "kidney or adrenal gland cancers with alcohol history",
@@ -45,57 +44,11 @@ FACETS_STR = ",".join(
45
  ]
46
  )
47
 
48
- if not DEBUG:
49
- tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
50
- # for some reason, pre-invoking tokenizer prevents endless generation when using guidance
51
- # opened ticket here: https://github.com/guidance-ai/guidance/issues/1322
52
- tok("foobar")
53
- model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
54
- model = model.to("cuda" if torch.cuda.is_available() else "cpu")
55
- model = model.eval()
56
 
57
-
58
- DUMMY_FILTER = json.dumps(
59
- {
60
- "op": "and",
61
- "content": [
62
- {
63
- "op": "in",
64
- "content": {
65
- "field": "cases.project.project_id",
66
- "value": ["TCGA-BRCA"],
67
- },
68
- },
69
- {
70
- "op": "in",
71
- "content": {
72
- "field": "cases.project.program.name",
73
- "value": ["TCGA"],
74
- },
75
- },
76
- {
77
- "op": "and",
78
- "content": [
79
- {
80
- "op": ">=",
81
- "content": {
82
- "field": "cases.diagnoses.age_at_diagnosis",
83
- "value": 7305,
84
- },
85
- },
86
- {
87
- "op": "<=",
88
- "content": {
89
- "field": "cases.diagnoses.age_at_diagnosis",
90
- "value": 14610,
91
- },
92
- },
93
- ],
94
- },
95
- ],
96
- },
97
- indent=4,
98
- )
99
 
100
 
101
  # Generate cohort filter JSON from free text
@@ -110,8 +63,6 @@ def generate_filter(query: str) -> str:
110
  Returns:
111
  str: JSON structured GDC cohort filter
112
  """
113
- if DEBUG:
114
- return DUMMY_FILTER
115
 
116
  set_seed(42)
117
  lm = Transformers(
@@ -525,7 +476,7 @@ function download_cases(filter_str) {{
525
  """
526
 
527
  with gr.Blocks(css_paths="style.css") as demo:
528
- gr.Markdown("# GDC Cohort Copilot")
529
 
530
  with gr.Row(equal_height=True):
531
  with gr.Column(scale=7):
@@ -593,22 +544,22 @@ with gr.Blocks(css_paths="style.css") as demo:
593
 
594
  with gr.Row():
595
  # Tab selectors
596
- tab_buttons = []
597
  with gr.Column(scale=1, min_width=250):
598
- for name in TAB_NAMES:
599
  tab_button = gr.Button(
600
- value=name,
601
- variant="primary" if name == TAB_NAMES[0] else "secondary",
602
  )
603
- tab_buttons.append(tab_button)
604
 
605
  # Filter cards
606
- tab_containers = []
607
- filter_cards = []
608
  for tab in CONFIG["tabs"]:
609
  visible = tab["name"] == TAB_NAMES[0] # default first card
610
  with gr.Column(scale=4, visible=visible) as tab_container:
611
- tab_containers.append(tab_container)
612
  with gr.Row(elem_classes=["card-group"]):
613
  for card in tab["cards"]:
614
  if isinstance(card["values"], list):
@@ -644,104 +595,100 @@ with gr.Blocks(css_paths="style.css") as demo:
644
  elem_classes=["filter-card", "filter-range"],
645
  )
646
 
647
- filter_cards.append(filter_card)
648
 
649
  # Assign tab buttons to toggle visibility
650
- for tab_button, name in zip(tab_buttons, TAB_NAMES):
651
- tab_button.click(
652
- fn=set_active_tab,
653
- inputs=gr.State(name),
654
- outputs=tab_containers + tab_buttons,
655
- api_name=False,
656
- )
657
 
658
  # Enable case download
659
- case_download.click(
660
- fn=None, # apparently this isn't the same as not specifying it
661
- js=DOWNLOAD_CASES_JS,
662
- inputs=json_output,
663
- api_name=False,
664
- )
665
 
666
  # Load initial counts on startup
667
- demo.load(
668
- fn=update_cards_with_counts,
669
- inputs=[gr.State("")] + filter_cards,
670
- outputs=filter_cards + [case_counter],
671
- api_name=False,
672
- )
673
 
674
  # Update checkboxes on filter generation
675
  # Also update JSON based on checkboxes
676
  # - relying on checkbox update to do this fires multiple times
677
  # - also propagates new model selections after json is updated
678
  # Also this way it shows the model generated JSON
679
- text_input.submit(
680
- fn=process_query,
681
- inputs=text_input,
682
- outputs=filter_cards + [json_output],
683
- api_name=False,
684
- ).success(
685
- fn=update_active_selections,
686
- inputs=filter_cards,
687
- outputs=[active_selections],
688
- api_name=False,
689
- )
690
 
691
  # Update JSON based on cards
692
  # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
693
  # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
694
- for filter_card in filter_cards:
695
- if isinstance(filter_card, RangeSlider):
696
- filter_card.release(
697
- fn=update_json_from_cards,
698
- inputs=filter_cards,
699
- outputs=json_output,
700
- api_name=False,
701
- ).success(
702
- fn=update_active_selections,
703
- inputs=filter_cards,
704
- outputs=[active_selections],
705
- api_name=False,
706
- )
707
- else:
708
- filter_card.input(
709
- fn=update_json_from_cards,
710
- inputs=filter_cards,
711
- outputs=json_output,
712
- api_name=False,
713
- ).success(
714
- fn=update_active_selections,
715
- inputs=filter_cards,
716
- outputs=[active_selections],
717
- api_name=False,
718
- )
719
 
720
  # Enable functionality of the active filter selectors
721
- active_selections.input(
722
- fn=update_cards_from_active,
723
- inputs=[active_selections] + filter_cards,
724
- outputs=[active_selections] + filter_cards,
725
- api_name=False,
726
- ).success(
727
- fn=update_json_from_cards,
728
- inputs=filter_cards,
729
- outputs=json_output,
730
- api_name=False,
731
- )
732
 
733
  # Update checkboxes after executing filter query
734
- json_output.change(
735
- fn=update_cards_with_counts,
736
- inputs=[json_output] + filter_cards,
737
- outputs=filter_cards + [case_counter],
738
- api_name=False,
739
- )
740
-
741
- def fn(a: int, b: int, c: list[str]) -> tuple[int, str]:
742
- return a + b, c[a:b]
743
 
744
- gr.api(fn, api_name="add_and_slice")
745
  # gr.api(generate_filter, api_name="generate_filter")
746
 
747
  if __name__ == "__main__":
 
14
 
15
  from schema import GDCCohortSchema # isort: skip
16
 
 
17
  EXAMPLE_INPUTS = [
18
  "bam files for TCGA-BRCA",
19
  "kidney or adrenal gland cancers with alcohol history",
 
44
  ]
45
  )
46
 
 
 
 
 
 
 
 
 
47
 
48
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
49
+ model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
50
+ model = model.to("cuda" if torch.cuda.is_available() else "cpu")
51
+ model = model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  # Generate cohort filter JSON from free text
 
63
  Returns:
64
  str: JSON structured GDC cohort filter
65
  """
 
 
66
 
67
  set_seed(42)
68
  lm = Transformers(
 
476
  """
477
 
478
  with gr.Blocks(css_paths="style.css") as demo:
479
+ gr.Markdown("# GDC Cohort Copilot - UNDER CONSTRUCTION")
480
 
481
  with gr.Row(equal_height=True):
482
  with gr.Column(scale=7):
 
544
 
545
  with gr.Row():
546
  # Tab selectors
547
+ tab_buttons = dict()
548
  with gr.Column(scale=1, min_width=250):
549
+ for tab_name in TAB_NAMES:
550
  tab_button = gr.Button(
551
+ value=tab_name,
552
+ variant="primary" if tab_name == TAB_NAMES[0] else "secondary",
553
  )
554
+ tab_buttons[tab_name] = tab_button
555
 
556
  # Filter cards
557
+ tab_containers = dict()
558
+ filter_cards = dict()
559
  for tab in CONFIG["tabs"]:
560
  visible = tab["name"] == TAB_NAMES[0] # default first card
561
  with gr.Column(scale=4, visible=visible) as tab_container:
562
+ tab_containers[tab["name"]] = tab_container
563
  with gr.Row(elem_classes=["card-group"]):
564
  for card in tab["cards"]:
565
  if isinstance(card["values"], list):
 
595
  elem_classes=["filter-card", "filter-range"],
596
  )
597
 
598
+ filter_cards[card["name"]] = filter_card
599
 
600
  # Assign tab buttons to toggle visibility
601
+ # for tab_button, name in zip(tab_buttons, TAB_NAMES):
602
+ # tab_button.click(
603
+ # fn=set_active_tab,
604
+ # inputs=gr.State(name),
605
+ # outputs=tab_containers + tab_buttons,
606
+ # api_name=False,
607
+ # )
608
 
609
  # Enable case download
610
+ # case_download.click(
611
+ # fn=None, # apparently this isn't the same as not specifying it
612
+ # js=DOWNLOAD_CASES_JS,
613
+ # inputs=json_output,
614
+ # api_name=False,
615
+ # )
616
 
617
  # Load initial counts on startup
618
+ # demo.load(
619
+ # fn=update_cards_with_counts,
620
+ # inputs=[gr.State("")] + filter_cards,
621
+ # outputs=filter_cards + [case_counter],
622
+ # api_name=False,
623
+ # )
624
 
625
  # Update checkboxes on filter generation
626
  # Also update JSON based on checkboxes
627
  # - relying on checkbox update to do this fires multiple times
628
  # - also propagates new model selections after json is updated
629
  # Also this way it shows the model generated JSON
630
+ # text_input.submit(
631
+ # fn=process_query,
632
+ # inputs=text_input,
633
+ # outputs=filter_cards + [json_output],
634
+ # api_name=False,
635
+ # ).success(
636
+ # fn=update_active_selections,
637
+ # inputs=filter_cards,
638
+ # outputs=[active_selections],
639
+ # api_name=False,
640
+ # )
641
 
642
  # Update JSON based on cards
643
  # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
644
  # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
645
+ # for filter_card in filter_cards:
646
+ # if isinstance(filter_card, RangeSlider):
647
+ # filter_card.release(
648
+ # fn=update_json_from_cards,
649
+ # inputs=filter_cards,
650
+ # outputs=json_output,
651
+ # api_name=False,
652
+ # ).success(
653
+ # fn=update_active_selections,
654
+ # inputs=filter_cards,
655
+ # outputs=[active_selections],
656
+ # api_name=False,
657
+ # )
658
+ # else:
659
+ # filter_card.input(
660
+ # fn=update_json_from_cards,
661
+ # inputs=filter_cards,
662
+ # outputs=json_output,
663
+ # api_name=False,
664
+ # ).success(
665
+ # fn=update_active_selections,
666
+ # inputs=filter_cards,
667
+ # outputs=[active_selections],
668
+ # api_name=False,
669
+ # )
670
 
671
  # Enable functionality of the active filter selectors
672
+ # active_selections.input(
673
+ # fn=update_cards_from_active,
674
+ # inputs=[active_selections] + filter_cards,
675
+ # outputs=[active_selections] + filter_cards,
676
+ # api_name=False,
677
+ # ).success(
678
+ # fn=update_json_from_cards,
679
+ # inputs=filter_cards,
680
+ # outputs=json_output,
681
+ # api_name=False,
682
+ # )
683
 
684
  # Update checkboxes after executing filter query
685
+ # json_output.change(
686
+ # fn=update_cards_with_counts,
687
+ # inputs=[json_output] + filter_cards,
688
+ # outputs=filter_cards + [case_counter],
689
+ # api_name=False,
690
+ # )
 
 
 
691
 
 
692
  # gr.api(generate_filter, api_name="generate_filter")
693
 
694
  if __name__ == "__main__":
app.py.old ADDED
@@ -0,0 +1,744 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from collections import defaultdict
4
+
5
+ import gradio as gr
6
+ import requests
7
+ import spaces
8
+ import torch
9
+ import yaml
10
+ from gradio_rangeslider import RangeSlider
11
+ from guidance import json as gen_json
12
+ from guidance.models import Transformers
13
+ from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
14
+
15
+ from schema import GDCCohortSchema # isort: skip
16
+
17
+ DEBUG = "DEBUG" in os.environ
18
+ EXAMPLE_INPUTS = [
19
+ "bam files for TCGA-BRCA",
20
+ "kidney or adrenal gland cancers with alcohol history",
21
+ "tumor samples from male patients with acute myeloid lymphoma",
22
+ ]
23
+
24
+ GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
25
+ MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
26
+ TOKENIZER_NAME = MODEL_NAME
27
+ AUTH_TOKEN = os.environ.get("HF_TOKEN", False) # HF_TOKEN must be set to use auth
28
+
29
+ with open("config.yaml", "r") as f:
30
+ CONFIG = yaml.safe_load(f)
31
+
32
+ TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]
33
+ CARD_NAMES = [card["name"] for tab in CONFIG["tabs"] for card in tab["cards"]]
34
+ CARD_FIELDS = [card["field"] for tab in CONFIG["tabs"] for card in tab["cards"]]
35
+ CARD_2_FIELD = dict(list(zip(CARD_NAMES, CARD_FIELDS)))
36
+ CARD_2_VALUES = {
37
+ card["name"]: card["values"] for tab in CONFIG["tabs"] for card in tab["cards"]
38
+ }
39
+ FACETS_STR = ",".join(
40
+ [
41
+ f.replace("cases.", "")
42
+ for f, n in zip(CARD_FIELDS, CARD_NAMES)
43
+ if not isinstance(CARD_2_VALUES[n], dict)
44
+ # ^ skip range facets in bin counts
45
+ ]
46
+ )
47
+
48
+ if not DEBUG:
49
+ tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
50
+ # for some reason, pre-invoking tokenizer prevents endless generation when using guidance
51
+ # opened ticket here: https://github.com/guidance-ai/guidance/issues/1322
52
+ tok("foobar")
53
+ model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
54
+ model = model.to("cuda" if torch.cuda.is_available() else "cpu")
55
+ model = model.eval()
56
+
57
+
58
+ DUMMY_FILTER = json.dumps(
59
+ {
60
+ "op": "and",
61
+ "content": [
62
+ {
63
+ "op": "in",
64
+ "content": {
65
+ "field": "cases.project.project_id",
66
+ "value": ["TCGA-BRCA"],
67
+ },
68
+ },
69
+ {
70
+ "op": "in",
71
+ "content": {
72
+ "field": "cases.project.program.name",
73
+ "value": ["TCGA"],
74
+ },
75
+ },
76
+ {
77
+ "op": "and",
78
+ "content": [
79
+ {
80
+ "op": ">=",
81
+ "content": {
82
+ "field": "cases.diagnoses.age_at_diagnosis",
83
+ "value": 7305,
84
+ },
85
+ },
86
+ {
87
+ "op": "<=",
88
+ "content": {
89
+ "field": "cases.diagnoses.age_at_diagnosis",
90
+ "value": 14610,
91
+ },
92
+ },
93
+ ],
94
+ },
95
+ ],
96
+ },
97
+ indent=4,
98
+ )
99
+
100
+
101
+ # Generate cohort filter JSON from free text
102
+ @spaces.GPU(duration=15)
103
+ def generate_filter(query: str) -> str:
104
+ """
105
+ Converts a free text description of a cancer cohort into a GDC structured cohort filter.
106
+
107
+ Args:
108
+ query (str): The free text cohort description
109
+
110
+ Returns:
111
+ str: JSON structured GDC cohort filter
112
+ """
113
+ if DEBUG:
114
+ return DUMMY_FILTER
115
+
116
+ set_seed(42)
117
+ lm = Transformers(
118
+ model=model,
119
+ tokenizer=tok,
120
+ # sampling_params=SamplingParams,
121
+ )
122
+ lm += query
123
+ lm += gen_json(
124
+ name="cohort", schema=GDCCohortSchema, temperature=0, max_tokens=1024
125
+ )
126
+ cohort_filter = lm["cohort"]
127
+ cohort_filter = json.dumps(json.loads(cohort_filter), indent=4)
128
+
129
+ return cohort_filter
130
+
131
+
132
+ # Transform query to filter to checkbox selections (and update json box)
133
+ def process_query(query):
134
+ # Generate filter
135
+ cohort_filter_str = generate_filter(query)
136
+ cohort_filter = json.loads(cohort_filter_str)
137
+
138
+ # Pre-flatten nested ops for easier mapping in next step
139
+ flattened_ops = []
140
+ for op in cohort_filter["content"]:
141
+ # nested `and` can only be 1 deep based on schema
142
+ if op["op"] == "and":
143
+ flattened_ops.extend(op["content"])
144
+ else:
145
+ flattened_ops.append(op)
146
+
147
+ # Prepare and validate generated filters
148
+ generated_field_2_values = dict()
149
+ for op in flattened_ops:
150
+ assert op["op"] in [
151
+ "in",
152
+ "=",
153
+ "<",
154
+ ">",
155
+ "<=",
156
+ ">=",
157
+ ], f"Unknown handling for op: {op}"
158
+ content = op["content"]
159
+ field, value = content["field"], content["value"]
160
+ # comparators are ints so can convert to g/lte by add/sub 1
161
+ if op["op"] == "<":
162
+ op["op"] = "<="
163
+ value -= 1
164
+ elif op["op"] == ">":
165
+ op["op"] = ">="
166
+ value += 1
167
+ elif op["op"] == "=":
168
+ # convert = to <=,>= ops so it can be filled into card
169
+ flattened_ops.append(
170
+ {
171
+ "op": "<=",
172
+ "content": content,
173
+ }
174
+ )
175
+ flattened_ops.append(
176
+ {
177
+ "op": ">=",
178
+ "content": content,
179
+ }
180
+ )
181
+ continue
182
+
183
+ if op["op"] != "in":
184
+ # comp ops will duplicate name, disambiguate by appending comp
185
+ field += "_" + op["op"]
186
+
187
+ if field in generated_field_2_values:
188
+ raise ValueError(f"{field} is ambiguously duplicated")
189
+ generated_field_2_values[field] = value
190
+
191
+ # Map filter selections to cards
192
+ card_updates = []
193
+ for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
194
+ # Need to update all cards so use all possible cards as ref
195
+ default_values = CARD_2_VALUES[card_name]
196
+ if isinstance(default_values, list):
197
+ updated_values = []
198
+ updated_choices = default_values # reset value
199
+ possible_values = set(updated_choices)
200
+ if card_field in generated_field_2_values:
201
+ # check ref against generated
202
+ selected_values = generated_field_2_values.pop(card_field)
203
+ unmatched_values = []
204
+ for selected_value in selected_values:
205
+ if selected_value in possible_values:
206
+ updated_values.append(selected_value)
207
+ else:
208
+ # model hallucination?
209
+ unmatched_values.append(selected_value)
210
+ if len(unmatched_values) > 0:
211
+ generated_field_2_values[card_field] = unmatched_values
212
+ update_obj = gr.update(
213
+ choices=updated_choices,
214
+ value=updated_values, # will override existing selections
215
+ )
216
+ elif isinstance(default_values, dict):
217
+ # range-slider, maybe other options in the future?
218
+ assert (
219
+ default_values["type"] == "range"
220
+ ), f"Expected range slider for card {card_name}"
221
+ # Need to handle if model outputs flat range or nested range
222
+ card_field_gte = card_field + "_>="
223
+ card_field_lte = card_field + "_<="
224
+ _min = default_values["min"]
225
+ _max = default_values["max"]
226
+ lo = generated_field_2_values.pop(card_field_gte, _min)
227
+ hi = generated_field_2_values.pop(card_field_lte, _max)
228
+ assert (
229
+ lo >= _min
230
+ ), f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
231
+ assert (
232
+ hi <= _max
233
+ ), f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
234
+ update_obj = gr.update(value=(lo, hi))
235
+ else:
236
+ raise ValueError(f"Unknown values for card {card_name}")
237
+ card_updates.append(update_obj)
238
+ # generated_field_2_values will have remaining, unmatched values
239
+ # edit: updated json schema with enumerated fields prevents unmatched fields
240
+ print(f"Unmatched values in model generation: {generated_field_2_values}")
241
+ return card_updates + [gr.update(value=cohort_filter_str)]
242
+
243
+
244
+ # Update JSON based on checkbox selections
245
+ def update_json_from_cards(*selected_filters_per_card):
246
+ ops = []
247
+ for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
248
+ # use the default values to determine card type (checkbox, range, etc)
249
+ default_values = CARD_2_VALUES[card_name]
250
+ if isinstance(default_values, list):
251
+ # checkbox
252
+ if len(selected_filters) > 0:
253
+ base_values = []
254
+ for selected_value in selected_filters:
255
+ base_value = get_base_value(selected_value)
256
+ base_values.append(base_value)
257
+ content = {
258
+ "field": CARD_2_FIELD[card_name],
259
+ "value": base_values,
260
+ }
261
+ op = {
262
+ "op": "in",
263
+ "content": content,
264
+ }
265
+ ops.append(op)
266
+ elif isinstance(default_values, dict):
267
+ # range-slider, maybe other options in the future?
268
+ assert (
269
+ default_values["type"] == "range"
270
+ ), f"Expected range slider for card {card_name}"
271
+ lo, hi = selected_filters
272
+ subops = []
273
+ for val, limit, comp in [
274
+ (lo, default_values["min"], ">="),
275
+ (hi, default_values["max"], "<="),
276
+ ]:
277
+ # only add range filter if not default
278
+ if val == limit:
279
+ continue
280
+ subop = {
281
+ "op": comp,
282
+ "content": {
283
+ "field": CARD_2_FIELD[card_name],
284
+ "value": int(val),
285
+ },
286
+ }
287
+ subops.append(subop)
288
+ if len(subops) > 0:
289
+ ops.append({"op": "and", "content": subops})
290
+ else:
291
+ raise ValueError(f"Unknown values for card {card_name}")
292
+
293
+ cohort_filter = {
294
+ "op": "and",
295
+ "content": ops,
296
+ }
297
+ filter_json = json.dumps(cohort_filter, indent=4)
298
+ return gr.update(value=filter_json)
299
+
300
+
301
+ # Execute GDC API query and prepare checkbox + case counter updates
302
+ # Preserve prior selections
303
+ def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
304
+ card_2_selections = dict(list(zip(CARD_NAMES, selected_filters_per_card)))
305
+
306
+ # Execute GDC API query
307
+ params = {
308
+ "facets": FACETS_STR,
309
+ "pretty": "false",
310
+ "format": "JSON",
311
+ "size": 0,
312
+ }
313
+
314
+ if cohort_filter:
315
+ # patch for range selectors which use nested `and`
316
+ # seems `facets` and nested `and` don't play well together
317
+ # so flatten direct nested `and` for query execution only
318
+ # this is equivalent since our top-level is always `and`
319
+ # keeping nested `and` for presentation and model generations though
320
+ temp = json.loads(cohort_filter)
321
+ ops = temp["content"]
322
+ new_ops = []
323
+ for op in ops:
324
+ # assumes no deeper than single level nesting
325
+ if op["op"] == "and":
326
+ for subop in op["content"]:
327
+ new_ops.append(subop)
328
+ else:
329
+ new_ops.append(op)
330
+ temp["content"] = new_ops
331
+ cohort_filter = json.dumps(temp)
332
+ params["filters"] = cohort_filter
333
+
334
+ response = requests.get(GDC_CASES_API_ENDPOINT, params=params)
335
+ if not response.ok:
336
+ raise Exception(f"API error: {response.status_code}\n{response.json()}")
337
+ temp = response.json()
338
+
339
+ # Update checkboxes with bin counts
340
+ card_updates = []
341
+ all_counts = temp["data"]["aggregations"]
342
+ for card_name in CARD_NAMES:
343
+ card_field = CARD_2_FIELD[card_name]
344
+ card_field = card_field.replace("cases.", "")
345
+ card_values = CARD_2_VALUES[card_name]
346
+ if isinstance(card_values, list):
347
+ # value checkboxes
348
+ choice_mapping = {}
349
+ updated_choices = []
350
+ card_counts = {
351
+ x["key"]: x["doc_count"] for x in all_counts[card_field]["buckets"]
352
+ }
353
+ for value_name in card_values:
354
+ if value_name in card_counts:
355
+ value_str = prepare_value_count(
356
+ value_name,
357
+ card_counts[value_name],
358
+ )
359
+ # track possible choices to use as values
360
+ choice_mapping[value_name] = value_str
361
+ updated_choices.append(value_str)
362
+
363
+ # Align prior selections with new choices
364
+ updated_values = []
365
+ for selected_value in card_2_selections[card_name]:
366
+ base_value = get_base_value(selected_value)
367
+ if base_value not in choice_mapping:
368
+ # Re-add choices which now presumably have 0 counts
369
+ choice_mapping[base_value] = prepare_value_count(base_value, 0)
370
+ updated_values.append(choice_mapping[base_value])
371
+
372
+ update_obj = gr.update(
373
+ choices=updated_choices,
374
+ value=updated_values,
375
+ )
376
+ elif isinstance(card_values, dict):
377
+ # range-slider, maybe other options in the future?
378
+ assert (
379
+ card_values["type"] == "range"
380
+ ), f"Expected range slider for card {card_name}"
381
+ # for range slider, nothing to actually do!
382
+ update_obj = gr.update()
383
+ else:
384
+ raise ValueError(f"Unknown values for card {card_name}")
385
+
386
+ card_updates.append(update_obj)
387
+
388
+ case_count = temp["data"]["pagination"]["total"]
389
+
390
+ return card_updates + [gr.update(value=f"{case_count} Cases")]
391
+
392
+
393
+ def update_active_selections(*selected_filters_per_card):
394
+ choices = []
395
+ for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
396
+ # use the default values to determine card type (checkbox, range, etc)
397
+ default_values = CARD_2_VALUES[card_name]
398
+ if isinstance(default_values, list):
399
+ # checkbox
400
+ for selected_value in selected_filters:
401
+ base_value = get_base_value(selected_value)
402
+ choices.append(f"{card_name.upper()}: {base_value}")
403
+ elif isinstance(default_values, dict):
404
+ # range-slider, maybe other options in the future?
405
+ assert (
406
+ default_values["type"] == "range"
407
+ ), f"Expected range slider for card {card_name}"
408
+ lo, hi = selected_filters
409
+ if lo != default_values["min"] or hi != default_values["max"]:
410
+ # only add range filter if not default
411
+ lo, hi = int(lo), int(hi)
412
+ choices.append(f"{card_name.upper()}: {lo}-{hi}")
413
+ else:
414
+ raise ValueError(f"Unknown values for card {card_name}")
415
+
416
+ return gr.update(choices=choices, value=choices)
417
+
418
+
419
+ def update_cards_from_active(current_selections, *selected_filters_per_card):
420
+ # active selector uses a flattened list so re-agg values under card groups
421
+ grouped_selections = defaultdict(set)
422
+ for k_v in current_selections:
423
+ idx = k_v.find(": ")
424
+ k, v = k_v[:idx], k_v[idx + 2 :]
425
+ grouped_selections[k].add(v)
426
+
427
+ card_updates = []
428
+ for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
429
+ # use the default values to determine card type (checkbox, range, etc)
430
+ default_values = CARD_2_VALUES[card_name]
431
+ if isinstance(default_values, list):
432
+ # checkbox
433
+ updated_values = []
434
+ for selected_value in selected_filters:
435
+ base_value = get_base_value(selected_value)
436
+ if base_value in grouped_selections[card_name.upper()]:
437
+ updated_values.append(selected_value)
438
+ update_obj = gr.update(value=updated_values)
439
+ elif isinstance(default_values, dict):
440
+ # range-slider, maybe other options in the future?
441
+ assert (
442
+ default_values["type"] == "range"
443
+ ), f"Expected range slider for card {card_name}"
444
+ # the active selector cannot change range values
445
+ # so if present as an active selection, no action is needed
446
+ # otherwise, reset entire range selector
447
+ if card_name.upper() in grouped_selections:
448
+ update_obj = gr.update()
449
+ else:
450
+ update_obj = gr.update(
451
+ value=(
452
+ default_values["min"],
453
+ default_values["max"],
454
+ )
455
+ )
456
+ else:
457
+ raise ValueError(f"Unknown values for card {card_name}")
458
+
459
+ card_updates.append(update_obj)
460
+
461
+ # also remove unselected value as possible choice
462
+ active_selection_update = gr.update(choices=current_selections)
463
+ return [active_selection_update] + card_updates
464
+
465
+
466
+ def prepare_value_count(value, count):
467
+ return f"{value} [{count}]"
468
+
469
+
470
+ def get_base_value(value):
471
+ if " [" in value:
472
+ value = value[: value.rfind(" [")]
473
+ return value
474
+
475
+
476
+ # Tab selection helper
477
+ def set_active_tab(selected_tab):
478
+ visibles = [gr.update(visible=(tab == selected_tab)) for tab in TAB_NAMES]
479
+ elem_classes = [
480
+ gr.update(variant="primary" if tab == selected_tab else "secondary")
481
+ for tab in TAB_NAMES
482
+ ]
483
+ return visibles + elem_classes
484
+
485
+
486
+ DOWNLOAD_CASES_JS = f"""
487
+ function download_cases(filter_str) {{
488
+ const params = new URLSearchParams();
489
+ params.set('fields', 'case_id');
490
+ params.set('format', 'JSON');
491
+ params.set('size', 100000);
492
+ params.set('filters', filter_str);
493
+ const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();
494
+
495
+ const button = document.getElementById("download-btn");
496
+ button.innerHTML = '<div class="spinner"><\div>';
497
+ button.disabled = true;
498
+
499
+ fetch(url).then(resp => {{
500
+ if (!resp.ok) throw new Error("Failed to fetch TSV.");
501
+ return resp.json();
502
+ }})
503
+ .then(data => {{
504
+ const ids = data.data.hits.map(item => item.id);
505
+ const text = ids.join("\\n");
506
+ const blob = new Blob([text], {{type: "text/plain"}});
507
+ return blob;
508
+ }})
509
+ .then(blob => {{
510
+ const url = URL.createObjectURL(blob);
511
+ const a = document.createElement('a');
512
+ a.href = url;
513
+ a.download = "gdc_cohort_case_ids.tsv";
514
+ document.body.appendChild(a);
515
+ a.click();
516
+ document.body.removeChild(a);
517
+ URL.revokeObjectURL(url);
518
+ button.innerHTML = 'Export to GDC';
519
+ button.disabled = false;
520
+ }})
521
+ .catch(error => {{
522
+ alert("Download failed: " + error.message);
523
+ }});
524
+ }}
525
+ """
526
+
527
+ with gr.Blocks(css_paths="style.css") as demo:
528
+ gr.Markdown("# GDC Cohort Copilot")
529
+
530
+ with gr.Row(equal_height=True):
531
+ with gr.Column(scale=7):
532
+ text_input = gr.Textbox(
533
+ label="Describe the cohort you're looking for:",
534
+ info=(
535
+ "Only provide the cohort characteristics. "
536
+ "Do not include extraneous text. "
537
+ "For example, write 'patients with X' "
538
+ "instead of 'I would like patients with X':"
539
+ ),
540
+ submit_btn="Generate Cohort",
541
+ elem_id="description-input",
542
+ placeholder="Enter a cohort description to begin...",
543
+ )
544
+ with gr.Column(scale=1, min_width=150):
545
+ case_counter = gr.Text(
546
+ show_label=False,
547
+ interactive=False,
548
+ container=False,
549
+ elem_id="case-counter",
550
+ min_width=150,
551
+ )
552
+ case_download = gr.Button(
553
+ value="Export to GDC",
554
+ min_width=150,
555
+ elem_id="download-btn",
556
+ )
557
+
558
+ with gr.Row(equal_height=True):
559
+ with gr.Column(scale=1, min_width=250):
560
+ gr.Examples(
561
+ examples=EXAMPLE_INPUTS,
562
+ inputs=text_input,
563
+ )
564
+ with gr.Column(scale=4):
565
+ json_output = gr.Code(
566
+ label="Cohort Filter JSON",
567
+ value=json.dumps({"op": "and", "content": []}, indent=4),
568
+ language="json",
569
+ interactive=False,
570
+ show_label=True,
571
+ container=True,
572
+ elem_id="json-output",
573
+ )
574
+
575
+ with gr.Row(equal_height=True):
576
+ with gr.Column(scale=1, min_width=250):
577
+ gr.Markdown("## Currently Selected Filters")
578
+ with gr.Column(scale=4):
579
+ active_selections = gr.CheckboxGroup(
580
+ choices=[],
581
+ show_label=False,
582
+ interactive=True,
583
+ elem_id="active-selections",
584
+ )
585
+
586
+ with gr.Row():
587
+ gr.Markdown(
588
+ "The generated cohort filter will autopopulate into the filter cards below. "
589
+ "**GDC Cohort Copilot can make mistakes!** "
590
+ "Refine your search using the interactive checkboxes. "
591
+ "Note that many other options can be found by selecting the different tabs on the left."
592
+ )
593
+
594
+ with gr.Row():
595
+ # Tab selectors
596
+ tab_buttons = []
597
+ with gr.Column(scale=1, min_width=250):
598
+ for name in TAB_NAMES:
599
+ tab_button = gr.Button(
600
+ value=name,
601
+ variant="primary" if name == TAB_NAMES[0] else "secondary",
602
+ )
603
+ tab_buttons.append(tab_button)
604
+
605
+ # Filter cards
606
+ tab_containers = []
607
+ filter_cards = []
608
+ for tab in CONFIG["tabs"]:
609
+ visible = tab["name"] == TAB_NAMES[0] # default first card
610
+ with gr.Column(scale=4, visible=visible) as tab_container:
611
+ tab_containers.append(tab_container)
612
+ with gr.Row(elem_classes=["card-group"]):
613
+ for card in tab["cards"]:
614
+ if isinstance(card["values"], list):
615
+ filter_card = gr.CheckboxGroup(
616
+ choices=[],
617
+ label=card["name"],
618
+ interactive=True,
619
+ elem_classes=["filter-card"],
620
+ )
621
+ else:
622
+ # values is a dictionary and defines some meta options
623
+ metaopts = card["values"]
624
+ assert (
625
+ "type" in metaopts
626
+ and metaopts["type"] == "range"
627
+ and all(
628
+ k in metaopts
629
+ for k in [
630
+ "min",
631
+ "max",
632
+ ]
633
+ )
634
+ ), f"Unknown meta options for {card['name']}"
635
+ info = "Inclusive range"
636
+ if "unit" in metaopts:
637
+ info += f", units in {metaopts['unit']}"
638
+ filter_card = RangeSlider(
639
+ label=card["name"],
640
+ info=info,
641
+ minimum=metaopts["min"],
642
+ maximum=metaopts["max"],
643
+ step=1, # assume integer
644
+ elem_classes=["filter-card", "filter-range"],
645
+ )
646
+
647
+ filter_cards.append(filter_card)
648
+
649
+ # Assign tab buttons to toggle visibility
650
+ for tab_button, name in zip(tab_buttons, TAB_NAMES):
651
+ tab_button.click(
652
+ fn=set_active_tab,
653
+ inputs=gr.State(name),
654
+ outputs=tab_containers + tab_buttons,
655
+ api_name=False,
656
+ )
657
+
658
+ # Enable case download
659
+ case_download.click(
660
+ fn=None, # apparently this isn't the same as not specifying it
661
+ js=DOWNLOAD_CASES_JS,
662
+ inputs=json_output,
663
+ api_name=False,
664
+ )
665
+
666
+ # Load initial counts on startup
667
+ demo.load(
668
+ fn=update_cards_with_counts,
669
+ inputs=[gr.State("")] + filter_cards,
670
+ outputs=filter_cards + [case_counter],
671
+ api_name=False,
672
+ )
673
+
674
+ # Update checkboxes on filter generation
675
+ # Also update JSON based on checkboxes
676
+ # - relying on checkbox update to do this fires multiple times
677
+ # - also propagates new model selections after json is updated
678
+ # Also this way it shows the model generated JSON
679
+ text_input.submit(
680
+ fn=process_query,
681
+ inputs=text_input,
682
+ outputs=filter_cards + [json_output],
683
+ api_name=False,
684
+ ).success(
685
+ fn=update_active_selections,
686
+ inputs=filter_cards,
687
+ outputs=[active_selections],
688
+ api_name=False,
689
+ )
690
+
691
+ # Update JSON based on cards
692
+ # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
693
+ # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
694
+ for filter_card in filter_cards:
695
+ if isinstance(filter_card, RangeSlider):
696
+ filter_card.release(
697
+ fn=update_json_from_cards,
698
+ inputs=filter_cards,
699
+ outputs=json_output,
700
+ api_name=False,
701
+ ).success(
702
+ fn=update_active_selections,
703
+ inputs=filter_cards,
704
+ outputs=[active_selections],
705
+ api_name=False,
706
+ )
707
+ else:
708
+ filter_card.input(
709
+ fn=update_json_from_cards,
710
+ inputs=filter_cards,
711
+ outputs=json_output,
712
+ api_name=False,
713
+ ).success(
714
+ fn=update_active_selections,
715
+ inputs=filter_cards,
716
+ outputs=[active_selections],
717
+ api_name=False,
718
+ )
719
+
720
+ # Enable functionality of the active filter selectors
721
+ active_selections.input(
722
+ fn=update_cards_from_active,
723
+ inputs=[active_selections] + filter_cards,
724
+ outputs=[active_selections] + filter_cards,
725
+ api_name=False,
726
+ ).success(
727
+ fn=update_json_from_cards,
728
+ inputs=filter_cards,
729
+ outputs=json_output,
730
+ api_name=False,
731
+ )
732
+
733
+ # Update checkboxes after executing filter query
734
+ json_output.change(
735
+ fn=update_cards_with_counts,
736
+ inputs=[json_output] + filter_cards,
737
+ outputs=filter_cards + [case_counter],
738
+ api_name=False,
739
+ )
740
+
741
+ # gr.api(generate_filter, api_name="generate_filter")
742
+
743
+ if __name__ == "__main__":
744
+ demo.launch(ssr_mode=False)
requirements.txt CHANGED
@@ -2,7 +2,7 @@ torch==2.5.1
2
  transformers==4.50.0
3
  gradio==5.49.1
4
  mcp==1.10.1
5
- guidance==0.2.4
6
  gradio_rangeslider
7
  spaces
8
  fastapi==0.116.1
 
2
  transformers==4.50.0
3
  gradio==5.49.1
4
  mcp==1.10.1
5
+ guidance==0.3.0
6
  gradio_rangeslider
7
  spaces
8
  fastapi==0.116.1