Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import os | |
| from collections import defaultdict | |
| import gradio as gr | |
| import requests | |
| import spaces | |
| import torch | |
| import yaml | |
| from gradio_rangeslider import RangeSlider | |
| from guidance import json as gen_json | |
| from guidance.models import Transformers | |
| from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed | |
| from schema import GDCCohortSchema # isort: skip | |
# Set the DEBUG environment variable (to any value) to skip model loading;
# generate_filter then returns a canned filter instead of running the model.
DEBUG = "DEBUG" in os.environ

EXAMPLE_INPUTS = [
    "bam files for TCGA-BRCA",
    "kidney or adrenal gland cancers with alcohol history",
    "tumor samples from male patients with acute myeloid lymphoma",
]

GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
TOKENIZER_NAME = MODEL_NAME
AUTH_TOKEN = os.environ.get("HF_TOKEN", False)  # HF_TOKEN must be set to use auth

# UI layout config: a list of tabs, each with a `name` and a list of `cards`.
# Each card has a display `name`, a GDC `field`, and `values`, which is either
# a list of checkbox options or a dict of range-slider options
# (`type`, `min`, `max`, optional `unit`).
# Explicit UTF-8 so decoding does not depend on the host's locale.
with open("config.yaml", "r", encoding="utf-8") as config_file:
    CONFIG = yaml.safe_load(config_file)

# All cards across all tabs, in display order (tab order, then card order).
_ALL_CARDS = [card for tab in CONFIG["tabs"] for card in tab["cards"]]

TAB_NAMES = [tab["name"] for tab in CONFIG["tabs"]]
CARD_NAMES = [card["name"] for card in _ALL_CARDS]
CARD_FIELDS = [card["field"] for card in _ALL_CARDS]
CARD_2_FIELD = dict(zip(CARD_NAMES, CARD_FIELDS))
CARD_2_VALUES = {card["name"]: card["values"] for card in _ALL_CARDS}

# Comma-separated facet list for the GDC cases API, with the `cases.` prefix
# removed. Range cards (dict `values`) are skipped: they get no bin counts.
FACETS_STR = ",".join(
    field.replace("cases.", "")
    for field, name in zip(CARD_FIELDS, CARD_NAMES)
    if not isinstance(CARD_2_VALUES[name], dict)
)
# Load the tokenizer and model once at import time (skipped entirely in DEBUG
# mode, where `tok` and `model` are never defined and never used).
if not DEBUG:
    tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=AUTH_TOKEN)
    # Workaround: calling the tokenizer once up front prevents endless
    # generation under guidance.
    # Ticket: https://github.com/guidance-ai/guidance/issues/1322
    tok("foobar")
    # Put the model on the GPU when one is visible, otherwise the CPU, in
    # inference (eval) mode.
    model = (
        GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=AUTH_TOKEN)
        .to("cuda" if torch.cuda.is_available() else "cpu")
        .eval()
    )
# Canned cohort filter that generate_filter returns in DEBUG mode, so the UI can
# be exercised without loading the model. It has the same nested shape the
# model produces: a top-level `and` of `in` ops plus one nested `and` holding a
# range. The age bounds are presumably in days (7305 ~ 20 years,
# 14610 ~ 40 years) -- confirm against the GDC data dictionary.
DUMMY_FILTER = json.dumps(
    dict(
        op="and",
        content=[
            dict(
                op="in",
                content=dict(field="cases.project.project_id", value=["TCGA-BRCA"]),
            ),
            dict(
                op="in",
                content=dict(field="cases.project.program.name", value=["TCGA"]),
            ),
            dict(
                op="and",
                content=[
                    dict(
                        op=">=",
                        content=dict(
                            field="cases.diagnoses.age_at_diagnosis", value=7305
                        ),
                    ),
                    dict(
                        op="<=",
                        content=dict(
                            field="cases.diagnoses.age_at_diagnosis", value=14610
                        ),
                    ),
                ],
            ),
        ],
    ),
    indent=4,
)
# Generate cohort filter JSON from free text
def generate_filter(query: str) -> str:
    """
    Convert a free-text cancer cohort description into a GDC cohort filter.

    In DEBUG mode the canned DUMMY_FILTER is returned and the model is not run.
    Otherwise the query is used as the prompt for the GPT-2 model, and guidance
    constrains decoding to the GDCCohortSchema JSON schema (temperature 0, at
    most 1024 tokens, seed fixed at 42).

    NOTE(review): `spaces` is imported but no `@spaces.GPU` decorator is
    visible on this function -- confirm how GPU time is obtained on ZeroGPU.

    Args:
        query (str): The free text cohort description

    Returns:
        str: The generated filter, re-serialized as 4-space-indented JSON
    """
    if DEBUG:
        return DUMMY_FILTER

    set_seed(42)
    guided_lm = Transformers(model=model, tokenizer=tok)
    guided_lm += query
    guided_lm += gen_json(
        name="cohort",
        schema=GDCCohortSchema,
        temperature=0,
        max_tokens=1024,
    )
    # Round-trip through json to normalize the model's formatting.
    generated = json.loads(guided_lm["cohort"])
    return json.dumps(generated, indent=4)
# Ops the generated filter may contain once nested `and` groups are inlined.
_SUPPORTED_GENERATED_OPS = ("in", "=", "<", ">", "<=", ">=")


def _flatten_generated_ops(cohort_filter):
    """Return the leaf ops of a top-level `and` filter, inlining one nested `and` level."""
    flattened_ops = []
    for op in cohort_filter["content"]:
        # nested `and` can only be 1 deep based on schema
        if op["op"] == "and":
            flattened_ops.extend(op["content"])
        else:
            flattened_ops.append(op)
    return flattened_ops


def _collect_generated_values(ops):
    """Map each generated field to its value.

    `in` ops are keyed by the bare field name. Comparison ops are normalized to
    inclusive bounds and keyed `<field>_>=` / `<field>_<=`.

    Raises:
        ValueError: on an unsupported op, or if a key would be set twice.
    """
    generated_field_2_values = {}

    def _record(key, value):
        if key in generated_field_2_values:
            raise ValueError(f"{key} is ambiguously duplicated")
        generated_field_2_values[key] = value

    for op in ops:
        op_name = op["op"]
        if op_name not in _SUPPORTED_GENERATED_OPS:
            raise ValueError(f"Unknown handling for op: {op}")
        field, value = op["content"]["field"], op["content"]["value"]
        if op_name == "in":
            _record(field, value)
        elif op_name == "=":
            # an exact match becomes a degenerate <=/>= range so it fits a slider
            _record(field + "_<=", value)
            _record(field + "_>=", value)
        elif op_name == "<":
            # comparators are ints, so strict bounds become inclusive ones via -/+ 1
            _record(field + "_<=", value - 1)
        elif op_name == ">":
            _record(field + "_>=", value + 1)
        else:  # already inclusive: "<=" or ">="
            _record(field + "_" + op_name, value)
    return generated_field_2_values


def _checkbox_card_update(card_field, default_values, generated_field_2_values):
    """Select the generated values that are valid options; keep invalid ones as unmatched."""
    possible_values = set(default_values)
    updated_values = []
    if card_field in generated_field_2_values:
        unmatched_values = []
        for selected_value in generated_field_2_values.pop(card_field):
            if selected_value in possible_values:
                updated_values.append(selected_value)
            else:
                # model hallucination?
                unmatched_values.append(selected_value)
        if unmatched_values:
            generated_field_2_values[card_field] = unmatched_values
    # Reset choices to the raw config values; `value` overrides existing selections.
    return gr.update(choices=default_values, value=updated_values)


def _range_card_update(card_name, card_field, default_values, generated_field_2_values):
    """Fill a range slider from the generated `_>=` / `_<=` bounds (defaulting to min/max)."""
    if default_values["type"] != "range":
        raise ValueError(f"Expected range slider for card {card_name}")
    _min = default_values["min"]
    _max = default_values["max"]
    lo = generated_field_2_values.pop(card_field + "_>=", _min)
    hi = generated_field_2_values.pop(card_field + "_<=", _max)
    if lo < _min:
        raise ValueError(
            f"Generated lower bound ({lo}) less than minimum allowable value ({_min})"
        )
    if hi > _max:
        raise ValueError(
            f"Generated upper bound ({hi}) greater than maximum allowable value ({_max})"
        )
    return gr.update(value=(lo, hi))


# Transform query to filter to checkbox selections (and update json box)
def process_query(query):
    """Generate a cohort filter for `query` and map it onto the filter cards.

    Every card receives an update, so selections not in the new filter are
    cleared. Generated values that match no card are printed, not applied.

    Returns:
        One gr.update per card (in CARD_NAMES order), followed by a gr.update
        setting the JSON box to the generated filter text.

    Raises:
        ValueError: if the generated filter uses an unsupported op, sets the same
            field/bound twice, puts a range bound outside the card's min/max, or
            the config has a card of unknown type.
    """
    cohort_filter_str = generate_filter(query)
    cohort_filter = json.loads(cohort_filter_str)
    generated_field_2_values = _collect_generated_values(
        _flatten_generated_ops(cohort_filter)
    )

    # Map filter selections to cards
    card_updates = []
    for card_name, card_field in zip(CARD_NAMES, CARD_FIELDS):
        default_values = CARD_2_VALUES[card_name]
        if isinstance(default_values, list):
            update_obj = _checkbox_card_update(
                card_field, default_values, generated_field_2_values
            )
        elif isinstance(default_values, dict):
            update_obj = _range_card_update(
                card_name, card_field, default_values, generated_field_2_values
            )
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    # Whatever is left was generated but matched no card.
    # (The enumerated-field JSON schema should make this rare.)
    print(f"Unmatched values in model generation: {generated_field_2_values}")
    return card_updates + [gr.update(value=cohort_filter_str)]
# Update JSON based on checkbox selections
def update_json_from_cards(*selected_filters_per_card):
    """Rebuild the cohort filter JSON from the current value of every card.

    Card types:
        Checkbox card with at least one selection: contributes an `in` op on its
            field, with " [count]" suffixes stripped from the labels.
        Range card: contributes a nested `and` of `>=` / `<=` ops; a bound still
            at the card's min/max is omitted, and so is an empty `and`.

    Args:
        *selected_filters_per_card: Current value of every card, in CARD_NAMES order.

    Returns:
        A gr.update carrying the 4-space-indented top-level `and` filter.
    """
    top_level_ops = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        card_values = CARD_2_VALUES[card_name]  # tells us the card type

        if isinstance(card_values, list):
            # checkbox card: no selection means no constraint
            if len(selection) == 0:
                continue
            top_level_ops.append(
                {
                    "op": "in",
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        "value": [get_base_value(choice) for choice in selection],
                    },
                }
            )

        elif isinstance(card_values, dict):
            assert (
                card_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selection
            bounds = (
                (lo, card_values["min"], ">="),
                (hi, card_values["max"], "<="),
            )
            range_ops = [
                {
                    "op": comp,
                    "content": {
                        "field": CARD_2_FIELD[card_name],
                        "value": int(val),
                    },
                }
                for val, limit, comp in bounds
                if val != limit  # a bound left at its default is not a filter
            ]
            if range_ops:
                top_level_ops.append({"op": "and", "content": range_ops})

        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(
        value=json.dumps({"op": "and", "content": top_level_ops}, indent=4)
    )
# Seconds to wait for the GDC API. Without a timeout, `requests` can block the
# Gradio worker indefinitely on an unresponsive server.
_GDC_TIMEOUT_SECONDS = 60


# Execute GDC API query and prepare checkbox + case counter updates
# Preserve prior selections
def update_cards_with_counts(cohort_filter: str, *selected_filters_per_card):
    """Query the GDC cases API and refresh checkbox choices with bin counts.

    Args:
        cohort_filter: Cohort filter JSON string, or "" to query without a filter
            (used on page load).
        *selected_filters_per_card: Current value of every card, in CARD_NAMES order.

    Returns:
        One gr.update per card followed by a gr.update for the case counter.
        Checkbox cards:
            - Choices become "value [count]" for every configured value present
              in the response buckets, in config order.
            - Prior selections are kept.
            - A selected value missing from the buckets is re-added as a
              "value [0]" choice, so the value is always one of the choices.
        Range cards are left unchanged.

    Raises:
        Exception: if the API response is not OK (status >= 400).
        ValueError: for a card of unknown type or a dict card that is not a range.
        requests.Timeout: if the API does not answer within _GDC_TIMEOUT_SECONDS.
    """
    card_2_selections = dict(zip(CARD_NAMES, selected_filters_per_card))

    # Execute GDC API query
    params = {
        "facets": FACETS_STR,
        "pretty": "false",
        "format": "JSON",
        "size": 0,
    }
    if cohort_filter:
        # Range cards use a nested `and`, but `facets` and nested `and` seem not
        # to play well together. So inline one level of nested `and` for query
        # execution only. This is equivalent because the top level is always
        # `and`; the nested form is kept for display and model generations.
        parsed_filter = json.loads(cohort_filter)
        flat_ops = []
        for op in parsed_filter["content"]:
            # assumes no deeper than single level nesting
            if op["op"] == "and":
                flat_ops.extend(op["content"])
            else:
                flat_ops.append(op)
        parsed_filter["content"] = flat_ops
        params["filters"] = json.dumps(parsed_filter)

    response = requests.get(
        GDC_CASES_API_ENDPOINT, params=params, timeout=_GDC_TIMEOUT_SECONDS
    )
    if not response.ok:
        # Use the raw text: error bodies are not guaranteed to be JSON, and a
        # decode error here would hide the real API error.
        raise Exception(f"API error: {response.status_code}\n{response.text}")
    response_data = response.json()["data"]

    # Update checkboxes with bin counts
    all_counts = response_data["aggregations"]
    card_updates = []
    for card_name in CARD_NAMES:
        card_values = CARD_2_VALUES[card_name]
        if isinstance(card_values, list):
            # aggregation keys omit the `cases.` prefix (as in FACETS_STR)
            card_field = CARD_2_FIELD[card_name].replace("cases.", "")
            card_counts = {
                bucket["key"]: bucket["doc_count"]
                for bucket in all_counts[card_field]["buckets"]
            }
            # base value -> "value [count]" label, for every choice offered
            choice_mapping = {}
            updated_choices = []
            for value_name in card_values:
                if value_name in card_counts:
                    value_str = prepare_value_count(value_name, card_counts[value_name])
                    choice_mapping[value_name] = value_str
                    updated_choices.append(value_str)

            # Align prior selections with the new choices
            updated_values = []
            for selected_value in card_2_selections[card_name]:
                base_value = get_base_value(selected_value)
                if base_value not in choice_mapping:
                    # Absent from the buckets, presumably 0 cases now. Re-add it
                    # as a choice too, otherwise the selected value would not be
                    # one of the checkbox's choices.
                    value_str = prepare_value_count(base_value, 0)
                    choice_mapping[base_value] = value_str
                    updated_choices.append(value_str)
                updated_values.append(choice_mapping[base_value])

            update_obj = gr.update(choices=updated_choices, value=updated_values)
        elif isinstance(card_values, dict):
            # range-slider, maybe other options in the future?
            if card_values["type"] != "range":
                raise ValueError(f"Expected range slider for card {card_name}")
            # for range slider, nothing to actually do!
            update_obj = gr.update()
        else:
            raise ValueError(f"Unknown values for card {card_name}")
        card_updates.append(update_obj)

    case_count = response_data["pagination"]["total"]
    return card_updates + [gr.update(value=f"{case_count} Cases")]
def update_active_selections(*selected_filters_per_card):
    """Mirror every active card selection into the "Currently Selected Filters" group.

    Label formats:
        Checkbox selection: "CARD NAME: value", with the " [count]" suffix stripped.
        Range card whose bounds differ from its min/max: "CARD NAME: lo-hi".

    Every label is both a choice and checked.
    """
    labels = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        card_values = CARD_2_VALUES[card_name]  # tells us the card type

        if isinstance(card_values, list):
            labels.extend(
                f"{card_name.upper()}: {get_base_value(choice)}" for choice in selection
            )
        elif isinstance(card_values, dict):
            assert (
                card_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            lo, hi = selection
            # a range still spanning the full min..max is not an active filter
            at_default = lo == card_values["min"] and hi == card_values["max"]
            if not at_default:
                labels.append(f"{card_name.upper()}: {int(lo)}-{int(hi)}")
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    return gr.update(choices=labels, value=labels)
def update_cards_from_active(current_selections, *selected_filters_per_card):
    """Push de-selections in the active-filters group back onto the cards.

    Args:
        current_selections: The still-checked "CARD NAME: value" labels.
        *selected_filters_per_card: Current value of every card, in CARD_NAMES order.

    Card handling:
        Checkbox card: keeps only the selections whose base value is still listed
            under its upper-cased name.
        Range card: reset to its full min/max range when no label for it remains.
            Range labels can only be removed here, not edited.

    Returns:
        An update restricting the active group's choices to `current_selections`,
        followed by one update per card in CARD_NAMES order.
    """
    # Regroup the flat "KEY: value" labels by key (split on the first ": ")
    grouped_selections = defaultdict(set)
    for label in current_selections:
        sep_at = label.find(": ")
        grouped_selections[label[:sep_at]].add(label[sep_at + 2 :])

    card_updates = []
    for card_name, selection in zip(CARD_NAMES, selected_filters_per_card):
        card_values = CARD_2_VALUES[card_name]  # tells us the card type

        if isinstance(card_values, list):
            kept = [
                choice
                for choice in selection
                if get_base_value(choice) in grouped_selections[card_name.upper()]
            ]
            card_updates.append(gr.update(value=kept))
        elif isinstance(card_values, dict):
            assert (
                card_values["type"] == "range"
            ), f"Expected range slider for card {card_name}"
            if card_name.upper() in grouped_selections:
                # still active: leave the slider as it is
                card_updates.append(gr.update())
            else:
                # label removed: reset to the full range
                card_updates.append(
                    gr.update(value=(card_values["min"], card_values["max"]))
                )
        else:
            raise ValueError(f"Unknown values for card {card_name}")

    # Drop unchecked labels from the group's choices as well
    return [gr.update(choices=current_selections)] + card_updates
def prepare_value_count(value, count):
    """Return the checkbox label for `value` annotated with `count`, e.g. "male [12]"."""
    return "{} [{}]".format(value, count)
def get_base_value(value):
    """Strip a trailing " [<count>]" suffix, as added by prepare_value_count.

    Only a final bracket group made of decimal digits counts as a suffix, so a
    raw value that itself contains " [" (e.g. "foo [bar]") is returned unchanged.
    Previously everything from the last " [" onward was cut. Values without a
    suffix are returned as-is.

    >>> get_base_value("male [12]")
    'male'
    >>> get_base_value("foo [bar]")
    'foo [bar]'
    """
    head, sep, tail = value.rpartition(" [")
    if sep and tail.endswith("]") and tail[:-1].isdecimal():
        return head
    return value
# Tab selection helper
def set_active_tab(selected_tab):
    """Show only the selected tab's card column and highlight its button.

    Returns:
        Visibility updates for every tab container, followed by variant updates
        ("primary" for the selected tab, "secondary" otherwise) for every tab
        button, both in TAB_NAMES order.
    """
    container_updates = []
    button_updates = []
    for tab_name in TAB_NAMES:
        is_selected = tab_name == selected_tab
        container_updates.append(gr.update(visible=is_selected))
        button_updates.append(
            gr.update(variant="primary" if is_selected else "secondary")
        )
    return container_updates + button_updates
# Client-side handler for the "Export to GDC" button. It receives the current
# filter JSON and fetches matching case IDs straight from the GDC cases API in
# the browser. It then saves them one per line as gdc_cohort_case_ids.tsv.
# The `size` parameter presumably caps the number of hits returned; confirm for
# very large cohorts.
# Python f-string notes: doubled braces are literal JS braces, and "\\n" emits a
# JS "\n" escape.
DOWNLOAD_CASES_JS = f"""
function download_cases(filter_str) {{
    const params = new URLSearchParams();
    params.set('fields', 'case_id');
    params.set('format', 'JSON');
    params.set('size', 100000);
    params.set('filters', filter_str);
    const url = "{GDC_CASES_API_ENDPOINT}?" + params.toString();
    const button = document.getElementById("download-btn");
    button.innerHTML = '<div class="spinner"></div>';
    button.disabled = true;
    fetch(url).then(resp => {{
        if (!resp.ok) throw new Error("Failed to fetch case IDs.");
        return resp.json();
    }})
    .then(data => {{
        const ids = data.data.hits.map(item => item.id);
        const text = ids.join("\\n");
        const blob = new Blob([text], {{type: "text/plain"}});
        return blob;
    }})
    .then(blob => {{
        const url = URL.createObjectURL(blob);
        const a = document.createElement('a');
        a.href = url;
        a.download = "gdc_cohort_case_ids.tsv";
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
        URL.revokeObjectURL(url);
    }})
    .catch(error => {{
        alert("Download failed: " + error.message);
    }})
    .finally(() => {{
        // Restore the button after success AND failure; otherwise a failed
        // download left it disabled with the spinner showing.
        button.innerHTML = 'Export to GDC';
        button.disabled = false;
    }});
}}
"""
# Build the Gradio UI. Layout, top to bottom:
#   - description box, beside the case counter and export button
#   - examples, beside the generated filter JSON
#   - active-selection checkboxes
#   - a usage note
#   - tab buttons (left) beside one column of filter cards per config tab
# Event listeners are registered inside the Blocks context.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown("# GDC Cohort Copilot")
    with gr.Row(equal_height=True):
        with gr.Column(scale=7):
            # Submitting this box runs process_query (wired below)
            text_input = gr.Textbox(
                label="Describe the cohort you're looking for:",
                info=(
                    "Only provide the cohort characteristics. "
                    "Do not include extraneous text. "
                    "For example, write 'patients with X' "
                    "instead of 'I would like patients with X':"
                ),
                submit_btn="Generate Cohort",
                elem_id="description-input",
                placeholder="Enter a cohort description to begin...",
            )
        with gr.Column(scale=1, min_width=150):
            # "<N> Cases" text, set by update_cards_with_counts
            case_counter = gr.Text(
                show_label=False,
                interactive=False,
                container=False,
                elem_id="case-counter",
                min_width=150,
            )
            # elem_id must match the id DOWNLOAD_CASES_JS looks up
            case_download = gr.Button(
                value="Export to GDC",
                min_width=150,
                elem_id="download-btn",
            )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Examples(
                examples=EXAMPLE_INPUTS,
                inputs=text_input,
            )
        with gr.Column(scale=4):
            # Read-only view of the current cohort filter, starting empty.
            # Process_query and update_json_from_cards write it; its change
            # event triggers the API count refresh.
            json_output = gr.Code(
                label="Cohort Filter JSON",
                value=json.dumps({"op": "and", "content": []}, indent=4),
                language="json",
                interactive=False,
                show_label=True,
                container=True,
                elem_id="json-output",
            )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Currently Selected Filters")
        with gr.Column(scale=4):
            # Flat "CARD NAME: value" labels built by update_active_selections.
            # Unchecking one removes that filter via update_cards_from_active.
            active_selections = gr.CheckboxGroup(
                choices=[],
                show_label=False,
                interactive=True,
                elem_id="active-selections",
            )
    with gr.Row():
        gr.Markdown(
            "The generated cohort filter will autopopulate into the filter cards below. "
            "**GDC Cohort Copilot can make mistakes!** "
            "Refine your search using the interactive checkboxes. "
            "Note that many other options can be found by selecting the different tabs on the left."
        )
    with gr.Row():
        # Tab selectors (the first tab starts highlighted)
        tab_buttons = []
        with gr.Column(scale=1, min_width=250):
            for name in TAB_NAMES:
                tab_button = gr.Button(
                    value=name,
                    variant="primary" if name == TAB_NAMES[0] else "secondary",
                )
                tab_buttons.append(tab_button)
        # Filter cards: one column per config tab.
        # filter_cards is flat and built in the same tab/card order as
        # CARD_NAMES. Every callback relies on that when zipping card values
        # with CARD_NAMES.
        tab_containers = []
        filter_cards = []
        for tab in CONFIG["tabs"]:
            visible = tab["name"] == TAB_NAMES[0]  # default first tab
            with gr.Column(scale=4, visible=visible) as tab_container:
                tab_containers.append(tab_container)
                with gr.Row(elem_classes=["card-group"]):
                    for card in tab["cards"]:
                        if isinstance(card["values"], list):
                            # choices start empty; demo.load fills them with
                            # "value [count]" labels via update_cards_with_counts
                            filter_card = gr.CheckboxGroup(
                                choices=[],
                                label=card["name"],
                                interactive=True,
                                elem_classes=["filter-card"],
                            )
                        else:
                            # values is a dictionary and defines some meta options
                            metaopts = card["values"]
                            assert (
                                "type" in metaopts
                                and metaopts["type"] == "range"
                                and all(
                                    k in metaopts
                                    for k in [
                                        "min",
                                        "max",
                                    ]
                                )
                            ), f"Unknown meta options for {card['name']}"
                            info = "Inclusive range"
                            if "unit" in metaopts:
                                info += f", units in {metaopts['unit']}"
                            filter_card = RangeSlider(
                                label=card["name"],
                                info=info,
                                minimum=metaopts["min"],
                                maximum=metaopts["max"],
                                step=1,  # assume integer
                                elem_classes=["filter-card", "filter-range"],
                            )
                        filter_cards.append(filter_card)
    # Assign tab buttons to toggle visibility.
    # gr.State(name) captures each button's own tab name when the listener is
    # created.
    for tab_button, name in zip(tab_buttons, TAB_NAMES):
        tab_button.click(
            fn=set_active_tab,
            inputs=gr.State(name),
            outputs=tab_containers + tab_buttons,
            api_name=False,
        )
    # Enable case download (runs the DOWNLOAD_CASES_JS handler with the
    # current filter JSON as its argument)
    case_download.click(
        fn=None,  # apparently this isn't the same as not specifying it
        js=DOWNLOAD_CASES_JS,
        inputs=json_output,
        api_name=False,
    )
    # Load initial counts on startup.
    # The "" state means "no filter" (see `if cohort_filter:` in
    # update_cards_with_counts).
    demo.load(
        fn=update_cards_with_counts,
        inputs=[gr.State("")] + filter_cards,
        outputs=filter_cards + [case_counter],
        api_name=False,
    )
    # Update checkboxes on filter generation
    # Also update JSON based on checkboxes
    # - relying on checkbox update to do this fires multiple times
    # - also propagates new model selections after json is updated
    # Also this way it shows the model generated JSON
    text_input.submit(
        fn=process_query,
        inputs=text_input,
        outputs=filter_cards + [json_output],
        api_name=False,
    ).success(
        fn=update_active_selections,
        inputs=filter_cards,
        outputs=[active_selections],
        api_name=False,
    )
    # Update JSON based on cards
    # Keep user `input` event listener (vs `change`) otherwise will fire multiple times
    # Seems like otherwise it should be cyclical, Gradio must have some logic to prevent infinite loops
    # Range sliders use `release` instead, presumably so they update once per
    # drag rather than on every step -- confirm against gradio_rangeslider.
    for filter_card in filter_cards:
        if isinstance(filter_card, RangeSlider):
            filter_card.release(
                fn=update_json_from_cards,
                inputs=filter_cards,
                outputs=json_output,
                api_name=False,
            ).success(
                fn=update_active_selections,
                inputs=filter_cards,
                outputs=[active_selections],
                api_name=False,
            )
        else:
            filter_card.input(
                fn=update_json_from_cards,
                inputs=filter_cards,
                outputs=json_output,
                api_name=False,
            ).success(
                fn=update_active_selections,
                inputs=filter_cards,
                outputs=[active_selections],
                api_name=False,
            )
    # Enable functionality of the active filter selectors:
    # prune the cards first, then rebuild the JSON from them
    active_selections.input(
        fn=update_cards_from_active,
        inputs=[active_selections] + filter_cards,
        outputs=[active_selections] + filter_cards,
        api_name=False,
    ).success(
        fn=update_json_from_cards,
        inputs=filter_cards,
        outputs=json_output,
        api_name=False,
    )
    # Update checkboxes after executing filter query.
    # Any JSON change re-queries the API for counts and the case total.
    json_output.change(
        fn=update_cards_with_counts,
        inputs=[json_output] + filter_cards,
        outputs=filter_cards + [case_counter],
        api_name=False,
    )
    # NOTE: left disabled; every listener above also passes api_name=False
    # gr.api(generate_filter, api_name="generate_filter")
def _main():
    """Launch the Gradio app with server-side rendering disabled."""
    demo.launch(ssr_mode=False)


if __name__ == "__main__":
    _main()