File size: 30,008 Bytes
7792455
ab14f8c
37f8669
1477c63
 
7792455
 
 
 
f50685c
7792455
314ce90
7792455
 
 
 
37f8669
1477c63
33b5fb3
62c31e6
 
 
 
 
 
7792455
2b49949
33b5fb3
1477c63
 
7792455
 
 
 
 
 
 
 
2187395
7792455
 
 
314ce90
 
 
 
 
 
 
 
 
1477c63
 
 
 
 
 
 
 
 
 
 
 
 
 
7792455
1477c63
 
1dbb331
 
7792455
 
 
 
ec99d66
6cea09c
 
 
 
 
 
 
 
 
314ce90
7792455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1477c63
d520b4f
314ce90
d520b4f
1477c63
d520b4f
 
 
 
7792455
 
2187395
4823b22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2187395
 
 
d520b4f
b8b1f43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4823b22
7792455
84ff870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1477c63
84ff870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d520b4f
b8b1f43
d520b4f
b8b1f43
 
 
 
7792455
b8b1f43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314ce90
 
b8b1f43
 
314ce90
b8b1f43
314ce90
b8b1f43
314ce90
b8b1f43
314ce90
b8b1f43
 
 
 
 
 
 
 
 
 
 
 
 
 
314ce90
b8b1f43
314ce90
7792455
b8b1f43
 
 
 
7792455
d520b4f
 
 
1477c63
d520b4f
1477c63
 
 
 
 
d520b4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84ff870
d520b4f
 
 
 
b8b1f43
 
 
1477c63
 
b8b1f43
7792455
 
1477c63
d520b4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1477c63
12e9fbb
49d41a8
 
 
12e9fbb
49d41a8
12e9fbb
 
 
 
 
49d41a8
12e9fbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d41a8
 
1477c63
e202c66
 
 
 
 
e9ef774
 
 
1477c63
7792455
 
 
 
 
 
 
 
1477c63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7792455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314ce90
7792455
 
 
 
 
 
 
 
 
1477c63
7792455
 
 
 
 
37f8669
 
 
 
 
 
7792455
 
 
 
 
 
 
 
 
 
 
 
 
314ce90
7792455
 
 
 
 
1477c63
7792455
62c31e6
7792455
 
1477c63
314ce90
7792455
314ce90
7792455
 
 
 
 
1477c63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7792455
37f8669
 
 
 
 
 
 
 
 
 
 
7792455
 
42cd864
7792455
1dbb331
7792455
1dbb331
 
7792455
42cd864
7792455
 
42cd864
 
7792455
 
 
42cd864
7792455
 
314ce90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15e5f4e
314ce90
 
42cd864
7792455
b8b1f43
42cd864
 
 
 
 
1477c63
 
42cd864
7792455
b8b1f43
42cd864
b8b1f43
 
42cd864
1477c63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42cd864
7792455
b8b1f43
 
1477c63
10b55c8
 
 
 
 
 
b8b1f43
 
d520b4f
 
 
 
 
 
1477c63
 
d520b4f
 
 
 
 
 
1477c63
 
d520b4f
7792455
b8b1f43
 
49d41a8
 
12e9fbb
49d41a8
1477c63
 
49d41a8
37f8669
b8b1f43
4823b22
b8b1f43
4823b22
1477c63
 
 
4823b22
ec99d66
b8b1f43
e9ef774
65b20cd
e9ef774
 
8369f5c
 
1477c63
e9ef774
 
7792455
76345d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
import json
import os
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path

import gradio as gr
import requests
import spaces
import torch
import yaml
from gradio_rangeslider import RangeSlider
from guidance import json as gen_json
from guidance.models import Transformers
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed

from schema import GDCCohortSchema  # isort: skip
from scheduler import ParquetScheduler  # isort: skip

# Example prompts offered below the description box via gr.Examples.
EXAMPLE_INPUTS = [
    "bam files for TCGA-BRCA",
    "kidney or adrenal gland cancers with alcohol history",
    "tumor samples from male patients with acute myeloid lymphoma",
]

# GDC cases endpoint: queried server-side for case counts / facet buckets and
# browser-side (DOWNLOAD_CASES_JS) for the case-ID export.
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
# Hugging Face repo of the cohort-generation model; tokenizer comes from the same repo.
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
TOKENIZER_NAME = MODEL_NAME
# Optional Hugging Face access tokens from the environment (None when unset).
MODEL_READ_TOKEN = os.environ.get("MODEL_READ_TOKEN", None)
DATASET_WRITE_TOKEN = os.environ.get("DATASET_WRITE_TOKEN", None)

# Filter-card configuration. Layout as read below:
#   tabs: [{name, cards: [{name, field, values}]}]
# where `values` is either a list of checkbox options, or a dict describing a
# range slider (type == "range", with min/max and an optional unit).
with open("config.yaml", "r") as f:
    CONFIG = yaml.safe_load(f)

# Flattened views of the config. Card order follows the order cards appear
# across tabs, which is also the order the UI filter components are created in.
TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]
CARD_NAMES = [card["name"] for tab in CONFIG["tabs"] for card in tab["cards"]]
CARD_FIELDS = [card["field"] for tab in CONFIG["tabs"] for card in tab["cards"]]
CARD_2_FIELD = dict(list(zip(CARD_NAMES, CARD_FIELDS)))
FIELD_2_CARD = dict(list(zip(CARD_FIELDS, CARD_NAMES)))
CARD_2_VALUES = {
    card["name"]: card["values"] for tab in CONFIG["tabs"] for card in tab["cards"]
}
# Comma-separated facet list for the GDC API `facets` parameter. The `cases.`
# prefix is stripped, matching the keys later read back from `aggregations`.
FACETS_STR = ",".join(
    [
        f.replace("cases.", "")
        for f, n in zip(CARD_FIELDS, CARD_NAMES)
        if not isinstance(CARD_2_VALUES[n], dict)
        # ^ skip range facets in bin counts
    ]
)

# Optional HF dataset repo id for logging user votes. When PREF_DS is set, a
# write token is required and votes are appended to a ParquetScheduler (see
# scheduler.py for upload behavior) using the schema below.
PREF_DS = os.environ.get("PREF_DS", False)
if PREF_DS:
    assert DATASET_WRITE_TOKEN is not None
    scheduler = ParquetScheduler(
        repo_id=PREF_DS,
        token=DATASET_WRITE_TOKEN,
        schema={
            "prompt": {"_type": "Value", "dtype": "string"},
            "cohort_filter": {"_type": "Value", "dtype": "string"},
            "preference": {"_type": "Value", "dtype": "bool"},
            "timestamp": {"_type": "Value", "dtype": "string"},
        },
    )


# Load tokenizer and model once at import time; move to GPU when available and
# switch to eval (inference) mode.
tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=MODEL_READ_TOKEN)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=MODEL_READ_TOKEN)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model = model.eval()


# Generate cohort filter JSON from free text
@spaces.GPU(duration=15)
def generate_filter(query: str) -> str:
    """
    Converts a free text description of a cancer cohort into a GDC structured cohort filter.

    Args:
        query (str): The free text cohort description

    Returns:
        str: JSON structured GDC cohort filter
    """
    # NOTE: the docstring above is surfaced by Gradio as the public API/MCP
    # tool description, so implementation notes live in comments instead.

    # Fixed seed + temperature 0 keep generations repeatable for a given query.
    set_seed(42)
    session = Transformers(model=model, tokenizer=tok)
    session += query
    # Schema-guided JSON generation, captured under the "cohort" key.
    session += gen_json(
        name="cohort",
        schema=GDCCohortSchema,
        temperature=0,
        max_tokens=1024,
    )
    # Round-trip through json to pretty-print with a 4-space indent.
    generated = json.loads(session["cohort"])
    return json.dumps(generated, indent=4)


def _prepare_value_count(value: str, count: int) -> str:
    """Format a facet value with its case count as a checkbox label, e.g. ``"X [12]"``."""
    return "{} [{}]".format(value, count)


def _get_base_value(value_count: str) -> str:
    """Strip the trailing `` [<count>]`` suffix added by ``_prepare_value_count``.

    Only a suffix whose bracketed part is a decimal integer is removed. Values
    that arrive without a count (e.g. labels coming from the active-selections
    group) but legitimately contain `` [`` are returned unchanged instead of
    being truncated.

    Args:
        value_count: A checkbox label such as ``"TCGA-BRCA [12]"`` or a bare value.

    Returns:
        The value with any count suffix removed.
    """
    head, sep, tail = value_count.rpartition(" [")
    if sep and tail.endswith("]") and tail[:-1].isdecimal():
        return head
    return value_count


def _patch_range_filters_for_facet_endpoint(cohort_filter: str) -> str:
    """Flatten directly nested ``and`` ops so the filter works with ``facets``.

    Range cards are expressed as a nested ``and`` of ``>=``/``<=`` ops, which
    does not seem to combine well with the ``facets`` query. Since the top-level
    op is always ``and``, lifting a nested ``and``'s children into the parent is
    equivalent. The nested form is kept for display and model output; this
    flattened copy is only used to execute the request. Only one level of
    nesting is flattened.

    Args:
        cohort_filter: JSON-encoded cohort filter.

    Returns:
        The flattened filter, re-encoded with ``json.dumps`` defaults.
    """
    parsed = json.loads(cohort_filter)
    flattened = []
    for top_op in parsed["content"]:
        flattened.extend(top_op["content"] if top_op["op"] == "and" else [top_op])
    parsed["content"] = flattened
    return json.dumps(parsed)


def _convert_cohort_filter_to_lookup(cohort_filter: str) -> dict[str, int | list[str]]:
    # Pre-flatten nested ops for easier mapping in next step
    flattened_ops = []
    for op in json.loads(cohort_filter)["content"]:
        # nested `and` can only be 1 deep based on schema
        if op["op"] == "and":
            flattened_ops.extend(op["content"])
        else:
            flattened_ops.append(op)

    # Prepare and validate generated filters
    selected_field_2_values = dict()
    for op in flattened_ops:
        assert op["op"] in ["in", "=", "<", ">", "<=", ">="], f"Unknown handling for op: {op}"  # fmt: skip
        content = op["content"]
        field, value = content["field"], content["value"]
        if op["op"] == "=":
            # convert = to <=,>= ops so it can be filled into card
            # use flattened_ops as a queue, defer current op
            flattened_ops.append(
                {
                    "op": "<=",
                    "content": content,
                }
            )
            flattened_ops.append(
                {
                    "op": ">=",
                    "content": content,
                }
            )
            continue  # defer current op
        elif op["op"] == "<":
            # comparator values are ints so can convert to lte by sub 1
            op["op"] = "<="
            value -= 1
        elif op["op"] == ">":
            # comparator values are ints so can convert to gte by add 1
            op["op"] = ">="
            value += 1

        # comp ops will duplicate name, disambiguate by appending comp
        if op["op"] != "in":
            field += "_" + op["op"]

        # check that fields are not duplicated
        if field in selected_field_2_values:
            raise ValueError(f"{field} is ambiguously duplicated")

        selected_field_2_values[field] = value

    return selected_field_2_values


def _convert_cohort_filter_to_active_selections(cohort_filter: str) -> list[str]:
    """Build the "Currently Selected Filters" labels for a cohort filter.

    Checkbox cards yield one ``"<CARD NAME>: <value>"`` label per selected value
    that is among the card's configured values (others are skipped as likely
    model hallucinations). Range cards yield ``"<CARD NAME>: ≥<lo>"`` and/or
    ``"<CARD NAME>: ≤<hi>"`` only when a bound differs from the configured
    min/max.

    Args:
        cohort_filter: JSON-encoded cohort filter.

    Returns:
        Active-selection labels, in lookup order.

    Raises:
        ValueError: For a card configured with neither a list nor a dict, or a
            range key that is not a ``<=``/``>=`` bound.
        AssertionError: If a dict-configured card is not a range, or a range
            bound is not an int.
    """
    labels = []
    for key, selected in _convert_cohort_filter_to_lookup(cohort_filter).items():
        # drop the comparator suffix added by the lookup conversion
        card_name = FIELD_2_CARD[key.replace("_<=", "").replace("_>=", "")]
        prefix = card_name.upper()
        card_values = CARD_2_VALUES[card_name]

        if isinstance(card_values, list):
            allowed = set(card_values)
            # values outside the configured options are skipped (model hallucination?)
            labels.extend(f"{prefix}: {v}" for v in selected if v in allowed)
            continue

        if not isinstance(card_values, dict):
            raise ValueError(f"Unknown values for card {card_name}")

        # range slider, maybe other options in the future?
        assert card_values["type"] == "range", f"Expected range slider for card {card_name}"  # fmt: skip
        assert isinstance(selected, int), "values should be integer for range op"
        if ">=" in key:
            if selected != card_values["min"]:
                labels.append(f"{prefix}: ≥{selected}")
        elif "<=" in key:
            if selected != card_values["max"]:
                labels.append(f"{prefix}: ≤{selected}")
        else:
            raise ValueError(f"Unclear how field is not l/gte: {key}")

    return labels


def _convert_cohort_filter_to_cards(cohort_filter: str, api_data: dict) -> list[dict]:
    """Build one ``gr.update`` per filter card from a cohort filter and API facets.

    Checkbox cards: choices are the union of the selected values and the API
    bucket counts, labelled ``"<value> [<count>]"``; selected values are listed
    first (each group sorted) and pre-checked. Values not in the card's
    configured list are skipped.
    Range cards: value is ``(lo, hi)`` taken from the filter's ``>=``/``<=``
    bounds, defaulting to the configured min/max.

    Args:
        cohort_filter: JSON-encoded cohort filter.
        api_data: The ``data`` object of a GDC cases response; its
            ``aggregations`` must hold a ``buckets`` list for every checkbox
            card's field (keyed without the ``cases.`` prefix).

    Returns:
        Card updates in ``CARD_NAMES`` order.

    Raises:
        AssertionError: If a dict-configured card is not a range, or a
            generated bound lies outside the configured min/max.
        ValueError: For a card configured with neither a list nor a dict.
    """
    # create lookup to use while iterating through filter card updates
    # (entries are popped as cards consume them; leftovers are reported at the end)
    selected_field_2_values = _convert_cohort_filter_to_lookup(cohort_filter)

    # prepare card updates, use selected values to check boxes
    # values are given by the union of selected values and bucket counts
    # (some selected values may have 0 bucket counts)
    card_updates = []
    for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            # checkbox selector
            updated_choices = []  # the possible checkboxes
            updated_values = []  # the selected checkboxes
            other_choices = []  # separate out for sorting
            # aggregations are keyed by the field without the `cases.` prefix
            bucket_counts = api_data["aggregations"][card_field.replace("cases.", "")]["buckets"]  # fmt: skip
            bucket_counts = {x["key"]: x["doc_count"] for x in bucket_counts}
            possible_values = set(default_values)

            # selected values go first as both values and choices
            if card_field in selected_field_2_values:
                unmatched_values = []
                selected_values = selected_field_2_values.pop(card_field)
                for selected_value in selected_values:
                    if selected_value not in possible_values:
                        print(
                            f"{card_field} value {selected_value} is not in the "
                            "list of default values, is this a model hallucination?"
                        )
                        unmatched_values.append(selected_value)
                        continue  # model hallucination? distinct from value with 0 count
                    # pop so this value is not listed again among the remaining choices
                    count = bucket_counts.pop(selected_value, 0)
                    value_count = _prepare_value_count(selected_value, count)
                    updated_choices.append(value_count)
                    updated_values.append(value_count)
                if len(unmatched_values) != 0:
                    # collect unmatched values back into selected_field_2_values
                    # which may otherwise be tracking unmatched fields
                    selected_field_2_values[card_field] = unmatched_values

            # fill in remaining possible values from bucket counts
            for other_choice, count in bucket_counts.items():
                if other_choice not in possible_values:
                    continue  # schema mismatch? i.e. if values are added
                other_choices.append(_prepare_value_count(other_choice, count))

            update_obj = gr.update(
                choices=sorted(updated_choices) + sorted(other_choices),
                value=updated_values,  # I think the order given here preserves selection order
            )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            # nothing to do with bucket counts for range slider
            assert (
                default_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            # Need to handle if model outputs flat range or nested range
            # (both arrive as `<field>_>=` / `<field>_<=` keys from the lookup)
            card_field_gte = card_field + "_>="
            card_field_lte = card_field + "_<="
            _min = default_values["min"]
            _max = default_values["max"]
            # a missing bound falls back to the slider's configured limit
            lo = selected_field_2_values.pop(card_field_gte, _min)
            hi = selected_field_2_values.pop(card_field_lte, _max)
            assert (
                lo >= _min
            ), f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
            assert (
                hi <= _max
            ), f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
            update_obj = gr.update(value=(lo, hi))
        else:
            raise ValueError(f"Unknown card type {card_name}")
        card_updates.append(update_obj)

    # selected_field_2_values may now have remaining, unmatched values
    # edit: updated json schema with enumerated fields should prevent unmatched fields
    if len(selected_field_2_values) != 0:
        print(f"Unmatched field/values in filter selections: {selected_field_2_values}")

    return card_updates


def update_elements_from_filtered_api_call(cohort_filter: str) -> list[dict]:
    """Query the GDC cases API for a cohort filter and build UI updates.

    Returns updates, in order, for:
    - the case counter (text)
    - the active selections (checkbox group)
    - the upvote button (re-enabled, text reset)
    - the downvote button (re-enabled, text reset)
    - every filter card, in ``CARD_NAMES`` order

    Args:
        cohort_filter: JSON-encoded cohort filter shown in the JSON panel.

    Returns:
        ``gr.update`` objects matching the event's ``outputs`` order.

    Raises:
        requests.exceptions.RequestException: On network failure or timeout.
        Exception: If the API answers with a non-success status.
    """
    # Upper bound (seconds) on connect/read so a stalled API cannot hang the handler.
    request_timeout = 60

    # --- Execute API Call ---
    # The request uses a copy with nested range `and`s flattened; the original
    # (nested) filter is still used below to populate the UI.
    patched_cohort_filter = _patch_range_filters_for_facet_endpoint(cohort_filter)

    params = {
        "filters": patched_cohort_filter,
        "facets": FACETS_STR,
        "pretty": "false",
        "format": "JSON",
        "size": 0,  # only counts/aggregations are needed, not hits
    }

    response = requests.get(
        GDC_CASES_API_ENDPOINT, params=params, timeout=request_timeout
    )
    if not response.ok:
        # Error bodies are not guaranteed to be JSON; fall back to raw text so
        # the status code is still reported instead of a decode error.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise Exception(f"API error: {response.status_code}\n{detail}")
    api_data = response.json()["data"]

    # --- Update Elements ---
    case_count = api_data["pagination"]["total"]
    active_choices = _convert_cohort_filter_to_active_selections(cohort_filter)
    card_updates = _convert_cohort_filter_to_cards(cohort_filter, api_data)

    return [
        gr.update(value=f"{case_count} Cases"),  # case counter
        gr.update(choices=active_choices, value=active_choices),  # actives
        gr.update(interactive=True, value="⬆"),
        gr.update(interactive=True, value="⬇"),
    ] + card_updates


def update_json_from_cards(*selected_filters_per_card: tuple[str]) -> dict:
    """Rebuild the cohort filter JSON from the current state of every filter card.

    Args:
        *selected_filters_per_card: One entry per card, in ``CARD_NAMES`` order.
            Checkbox cards pass their selected labels (a trailing `` [count]``
            is stripped); range cards pass a ``(lo, hi)`` pair.

    Returns:
        A ``gr.update`` dict for the JSON panel. Its value is the
        pretty-printed filter: a top-level ``and`` holding one ``in`` op per
        checkbox card with selections, and one nested ``and`` of ``>=``/``<=``
        ops per range card whose bounds differ from the configured min/max.

    Raises:
        ValueError: For a card configured with neither a list nor a dict.
        AssertionError: If a dict-configured card is not a range.
    """
    ops = []
    for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
        # the configured default values determine the card type (checkbox, range, etc)
        default_values = CARD_2_VALUES[card_name]
        field = CARD_2_FIELD[card_name]
        if isinstance(default_values, list):
            # checkbox: a card with nothing selected contributes no op
            if selected_filters:
                ops.append(
                    {
                        "op": "in",
                        "content": {
                            "field": field,
                            "value": [_get_base_value(v) for v in selected_filters],
                        },
                    }
                )
        elif isinstance(default_values, dict):
            # range-slider, maybe other options in the future?
            assert (
                default_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selected_filters
            subops = [
                {"op": comp, "content": {"field": field, "value": int(val)}}
                for val, limit, comp in (
                    (lo, default_values["min"], ">="),
                    (hi, default_values["max"], "<="),
                )
                # a bound left at its configured limit adds no constraint
                if val != limit
            ]
            if subops:
                ops.append({"op": "and", "content": subops})
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    cohort_filter = {"op": "and", "content": ops}
    return gr.update(value=json.dumps(cohort_filter, indent=4))


def update_json_from_active(active_selections: list[str]) -> dict:
    """Rebuild the cohort filter JSON from the "Currently Selected Filters" group.

    Each label has the form ``"<CARD NAME>: <value>"`` (card name upper-cased).
    Labels are grouped per card and turned back into the per-card inputs that
    ``update_json_from_cards`` expects: a list of values for checkbox cards,
    and a ``(lo, hi)`` pair for range cards, where a missing bound falls back
    to the configured min/max.

    Args:
        active_selections: Labels still checked in the active-selections group.

    Returns:
        The ``gr.update`` dict produced by ``update_json_from_cards``.

    Raises:
        ValueError: For an unknown card type or a duplicated range comparator.
        AssertionError: If a range card has more than two labels, or a label
            without a ``≥``/``≤`` comparator.
    """
    grouped_selections = defaultdict(list)
    for k_v in active_selections:
        k, _, v = k_v.partition(": ")
        grouped_selections[k].append(v)

    # mock-up as card selections and defer to update_json_from_cards
    selected_filters_per_card = []
    for card_name in CARD_NAMES:
        default_values = CARD_2_VALUES[card_name]
        label = card_name.upper()  # match active selections casing
        # absent cards behave like an empty selection
        selected_values = grouped_selections.get(label, [])
        if isinstance(default_values, list):
            # checkbox group: selected values (possibly none)
            selected_filters_per_card.append(selected_values)
        elif isinstance(default_values, dict):
            # range selector: at most one lower and one upper bound
            assert (
                len(selected_values) <= 2
            ), "Cannot do range op with more than 2 ops"
            assert all(
                "≥" in x or "≤" in x for x in selected_values
            ), "Unclear how ops besides l/gte are in active selection, did that logic change?"
            selected_range = dict()
            for x in selected_values:
                comp = ">=" if "≥" in x else "<="
                # labels look like "≥18" / "≤90": drop the leading symbol to get
                # the int; revisit if other comparators are ever rendered
                value = int(x[1:])
                if comp in selected_range:
                    raise ValueError(f"Duplicated comparator {comp} for {label}")
                selected_range[comp] = value
            selected_filters_per_card.append(
                (
                    selected_range.get(">=", default_values["min"]),
                    selected_range.get("<=", default_values["max"]),
                )
            )
        else:
            raise ValueError(f"Unknown card type for card: {label}")
    return update_json_from_cards(*selected_filters_per_card)


def get_default_filter() -> str:
    """Warn that the copilot should be used interactively and return an empty filter.

    Returns:
        A pretty-printed JSON filter: a top-level ``and`` with no conditions.
    """
    gr.Warning(
        title="GDC Cohort Copilot Should Be Used Interactively!",
        message="GDC Cohort Copilot can make mistakes. Interactively refine your search using the checkboxes.",
        duration=None,
    )
    empty_filter = {"op": "and", "content": []}
    return json.dumps(empty_filter, indent=4)


def set_active_tab(selected_tab: str) -> list[dict]:
    """Show only the selected tab's card column and highlight its button.

    Args:
        selected_tab: Name of the clicked tab (one of ``TAB_NAMES``).

    Returns:
        Visibility updates for every tab container, followed by variant updates
        for every tab button, each in ``TAB_NAMES`` order.
    """
    is_active = [name == selected_tab for name in TAB_NAMES]
    container_updates = [gr.update(visible=flag) for flag in is_active]
    button_updates = [
        gr.update(variant=("primary" if flag else "secondary")) for flag in is_active
    ]
    return container_updates + button_updates


def save_user_preference(cohort_query: str, cohort_filter: str, preference: bool) -> list[dict]:  # fmt: skip
    """Record an up/down vote on the current prompt and filter, then lock the vote buttons.

    When ``PREF_DS`` is set the record is appended to the preference dataset
    scheduler; otherwise it is only printed.

    Args:
        cohort_query: Free-text description currently in the input box.
        cohort_filter: Cohort filter JSON currently displayed.
        preference: True for an upvote, False for a downvote.

    Returns:
        Updates disabling both vote buttons, marking the chosen one (✓ / ✗) and
        setting the other to "--".
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    record = {
        "prompt": cohort_query,
        # round-trip through json to drop the pretty-print whitespace
        "cohort_filter": json.dumps(json.loads(cohort_filter)),
        "preference": preference,
        "timestamp": timestamp,
    }
    if not PREF_DS:
        print(
            "No preference dataset configured, "
            "set PREF_DS env var to point to a HuggingFace Dataset Repo. "
            f"Would have logged {record}"
        )
    else:
        scheduler.append(record)
        print(f"Logged user preference data at {timestamp}")

    # whitespace seems to be escaped by gradio, so "--" marks the unchosen button
    up_label, down_label = ("✓", "--") if preference else ("--", "✗")
    return [
        gr.update(interactive=False, value=up_label),
        gr.update(interactive=False, value=down_label),
    ]


# Browser-side handler for the "Export to GDC" button. Gradio passes the JSON
# panel's value as `filter_str`; the script fetches up to 100,000 case IDs
# matching the filter from the GDC cases API and downloads them as a
# newline-separated file. The button shows a spinner while the request runs and
# is restored in `.finally`, whether the download succeeds or fails.
# NOTE(review): the downvote button also uses elem_id "download-btn";
# getElementById returns the first match in document order (the export button).
DOWNLOAD_CASES_JS = f"""
function download_cases(filter_str) {{
    const params = new URLSearchParams();
    params.set('fields', 'case_id');
    params.set('format', 'JSON');
    params.set('size', 100000);
    params.set('filters', filter_str);
    const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();

    const button = document.getElementById("download-btn");
    button.innerHTML = '<div class="spinner"></div>';
    button.disabled = true;

    fetch(url).then(resp => {{
        if (!resp.ok) throw new Error("Failed to fetch case IDs.");
        return resp.json();
    }})
    .then(data => {{
        const ids = data.data.hits.map(item => item.id);
        const text = ids.join("\\n");
        const blob = new Blob([text], {{type: "text/plain"}});
        return blob;
    }})
    .then(blob => {{
        const url = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = url;
        a.download = "gdc_cohort_case_ids.tsv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(url);
    }})
    .catch(error => {{
        alert("Download failed: " + error.message);
    }})
    .finally(() => {{
        // always restore the button, including after a failed request
        button.innerHTML = 'Export to GDC';
        button.disabled = false;
    }});
}}
"""

# --- UI layout and event wiring ---
# Data flow: generation (text_input.submit) or manual edits (filter cards /
# active selections) rewrite json_output; json_output.change then re-queries
# the GDC API and refreshes the counter, active selections, vote buttons and cards.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown("# GDC Cohort Copilot")

    # Row 1: free-text description, plus case counter and export button.
    with gr.Row(equal_height=True):
        with gr.Column(scale=7):
            text_input = gr.Textbox(
                label="Describe the cohort you're looking for:",
                info=(
                    "Only provide the cohort characteristics. "
                    "Do not include extraneous text. "
                    "For example, write 'patients with X' "
                    "instead of 'I would like patients with X':"
                ),
                submit_btn="Generate Cohort",
                elem_id="description-input",
                placeholder="Enter a cohort description to begin...",
            )
        with gr.Column(scale=1, min_width=150):
            case_counter = gr.Text(
                show_label=False,
                interactive=False,
                container=False,
                elem_id="case-counter",
                min_width=150,
            )
            # DOWNLOAD_CASES_JS looks this button up by elem_id "download-btn".
            case_download = gr.Button(
                value="Export to GDC",
                min_width=150,
                elem_id="download-btn",
            )

    # Row 2: examples, generated filter JSON, and vote buttons.
    with gr.Row(equal_height=True):
        with gr.Column(scale=2, min_width=250):
            gr.Examples(
                examples=EXAMPLE_INPUTS,
                inputs=text_input,
            )
        with gr.Column(scale=7):
            json_output = gr.Code(
                label="Cohort Filter JSON",
                language="json",
                interactive=False,
                show_label=True,
                container=True,
                elem_id="json-output",
            )
        with gr.Column(scale=1, min_width=50):
            gr.Markdown(
                "Is this correct?",
                elem_id="vote-label",
            )
            upvote = gr.Button(
                value="⬆",
                min_width=50,
                elem_id="upvote-btn",
            )
            # NOTE(review): shares elem_id "download-btn" with the export
            # button (likely meant "downvote-btn"). Duplicate ids are invalid
            # HTML; confirm style.css selectors before renaming.
            downvote = gr.Button(
                value="⬇",
                min_width=50,
                elem_id="download-btn",
            )

    with gr.Row():
        gr.Markdown(
            "The generated cohort filter will autopopulate into the filter cards below. "
            "**<u>GDC Cohort Copilot can make mistakes!</u>** "
            "Refine your search using the interactive checkboxes. "
            "Note that many other options can be found by selecting the different tabs. "
            "**<u>If you'd like to help us improve our model</u>**, you can use the up or down vote button to send us feedback. "
            "We'll only save the current free text description, the cohort filter JSON, and your vote. "
            "You can also show us what the right filter should have been by manually refining it using the checkboxes, before up voting."
        )

    # Summary of active selections; unchecking one removes it from the filter.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Currently Selected Filters")
        with gr.Column(scale=4):
            active_selections = gr.CheckboxGroup(
                choices=[],
                show_label=False,
                interactive=True,
                elem_id="active-selections",
            )

    with gr.Row():
        # Tab selectors
        tab_buttons = []
        with gr.Column(scale=1, min_width=250):
            for tab_name in TAB_NAMES:
                tab_button = gr.Button(
                    value=tab_name,
                    variant="primary" if tab_name == TAB_NAMES[0] else "secondary",
                )
                tab_buttons.append(tab_button)

        # Filter cards
        # filter_cards is built in the same order as CARD_NAMES, which the
        # update/convert helpers rely on when zipping cards with their values.
        tab_containers = []
        filter_cards = []
        for tab in CONFIG["tabs"]:
            visible = tab["name"] == TAB_NAMES[0]  # default first card
            with gr.Column(scale=4, visible=visible) as tab_container:
                tab_containers.append(tab_container)
                with gr.Row(elem_classes=["card-group"]):
                    for card in tab["cards"]:
                        if isinstance(card["values"], list):
                            filter_card = gr.CheckboxGroup(
                                choices=[],
                                label=card["name"],
                                interactive=True,
                                elem_classes=["filter-card"],
                            )
                        else:
                            # values is a dictionary and defines some meta options
                            metaopts = card["values"]
                            assert (
                                "type" in metaopts
                                and metaopts["type"] == "range"
                                and all(
                                    k in metaopts
                                    for k in [
                                        "min",
                                        "max",
                                    ]
                                )
                            ), f"Unknown meta options for {card['name']}"
                            info = "Inclusive range"
                            if "unit" in metaopts:
                                info += f", units in {metaopts['unit']}"
                            filter_card = RangeSlider(
                                label=card["name"],
                                info=info,
                                minimum=metaopts["min"],
                                maximum=metaopts["max"],
                                step=1,  # assume integer
                                elem_classes=["filter-card", "filter-range"],
                            )

                        filter_cards.append(filter_card)

    # Toggle card group (tab) visibility
    for tab_button, name in zip(tab_buttons, TAB_NAMES):
        tab_button.click(
            fn=set_active_tab,
            inputs=gr.State(name),
            outputs=tab_containers + tab_buttons,
            # api_name=False,
            show_api=False,
        )

    # Callback for case download button
    case_download.click(
        fn=None,  # apparently this isn't the same as not specifying it, even though the default is None?
        js=DOWNLOAD_CASES_JS,  # need custom JSON to execute browser side download
        inputs=json_output,
        # api_name=False,
        show_api=False,
    )

    # Enable user preference logging
    upvote.click(
        fn=save_user_preference,
        inputs=[text_input, json_output, gr.State(True)],
        outputs=[upvote, downvote],
        # api_name=False,
        show_api=False,
    )
    downvote.click(
        fn=save_user_preference,
        inputs=[text_input, json_output, gr.State(False)],
        outputs=[upvote, downvote],
        # api_name=False,
        show_api=False,
    )

    # Model generation should change the JSON filter
    # All other element updates cascade
    # This is the only API that should be exposed
    text_input.submit(
        fn=generate_filter,
        inputs=text_input,
        outputs=json_output,
    )

    # Changing the card selections should change the JSON filter
    # All other element updates (including cards themselves) cascade
    # `.release` / `.input` (rather than `.change`) presumably keep the
    # programmatic card updates from json_output.change from re-triggering this.
    for filter_card in filter_cards:
        if isinstance(filter_card, RangeSlider):
            filter_card.release(
                fn=update_json_from_cards,
                inputs=filter_cards,
                outputs=json_output,
                # api_name=False,
                show_api=False,
            )
        else:
            filter_card.input(
                fn=update_json_from_cards,
                inputs=filter_cards,
                outputs=json_output,
                # api_name=False,
                show_api=False,
            )

    # Changing the active selections should change the JSON filter
    # All other element updates (including active selections itself) cascade
    active_selections.input(
        fn=update_json_from_active,
        inputs=active_selections,
        outputs=json_output,
        # api_name=False,
        show_api=False,
    )

    # JSON filter change executes API call and updates all elements
    # (outputs order must match the list returned by the handler)
    json_output.change(
        fn=update_elements_from_filtered_api_call,
        inputs=json_output,
        outputs=[case_counter, active_selections, upvote, downvote] + filter_cards,
        # api_name=False,
        show_api=False,
    )

    # Trigger initial update
    demo.load(
        fn=get_default_filter,
        inputs=None,
        outputs=json_output,
        # api_name=False,  # this breaks the API functionality, not sure why
        show_api=False,  # so just hide the API endpoints instead, not ideal
        # the weirdness with the API toggle seems true for all disabled API endpoints
    )

# Launch with Gradio's MCP server enabled alongside the web UI.
if __name__ == "__main__":
    demo.launch(mcp_server=True)