Spaces:
Running
on
Zero
Running
on
Zero
add user feedback system (#3)
Browse files- add user feedback system (c435b2e30ec6086adad543d3f083fe3b830b9e66)
- app.py +127 -35
- scheduler.py +136 -0
- style.css +8 -0
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
from collections import defaultdict
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
import requests
|
|
@@ -13,6 +15,7 @@ from guidance.models import Transformers
|
|
| 13 |
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
|
| 14 |
|
| 15 |
from schema import GDCCohortSchema # isort: skip
|
|
|
|
| 16 |
|
| 17 |
EXAMPLE_INPUTS = [
|
| 18 |
"bam files for TCGA-BRCA",
|
|
@@ -23,7 +26,8 @@ EXAMPLE_INPUTS = [
|
|
| 23 |
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
|
| 24 |
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
|
| 25 |
TOKENIZER_NAME = MODEL_NAME
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
with open("config.yaml", "r") as f:
|
| 29 |
CONFIG = yaml.safe_load(f)
|
|
@@ -45,9 +49,23 @@ FACETS_STR = ",".join(
|
|
| 45 |
]
|
| 46 |
)
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=
|
| 50 |
-
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=
|
| 51 |
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 52 |
model = model.eval()
|
| 53 |
|
|
@@ -81,11 +99,11 @@ def generate_filter(query: str) -> str:
|
|
| 81 |
return cohort_filter
|
| 82 |
|
| 83 |
|
| 84 |
-
def _prepare_value_count(value, count):
|
| 85 |
return f"{value} [{count}]"
|
| 86 |
|
| 87 |
|
| 88 |
-
def _get_base_value(value_count):
|
| 89 |
value = value_count
|
| 90 |
if " [" in value:
|
| 91 |
value = value[: value.rfind(" [")]
|
|
@@ -183,7 +201,7 @@ def _convert_cohort_filter_to_active_selections(cohort_filter: str) -> list[str]
|
|
| 183 |
active_choices.append(f"{card_name.upper()}: {value}")
|
| 184 |
elif isinstance(default_values, dict):
|
| 185 |
# range-slider, maybe other options in the future?
|
| 186 |
-
assert default_values["type"] == "range", f"Expected range slider for card {card_name}"
|
| 187 |
assert isinstance(values, int), "values should be integer for range op"
|
| 188 |
if ">=" in field:
|
| 189 |
if values != default_values["min"]:
|
|
@@ -281,9 +299,13 @@ def _convert_cohort_filter_to_cards(cohort_filter: str, api_data: dict) -> list[
|
|
| 281 |
return card_updates
|
| 282 |
|
| 283 |
|
| 284 |
-
def update_elements_from_filtered_api_call(cohort_filter: str):
|
| 285 |
# return updates for:
|
| 286 |
-
# counter (text)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
# --- Execute API Call ---
|
| 289 |
patched_cohort_filter = _patch_range_filters_for_facet_endpoint(cohort_filter)
|
|
@@ -309,10 +331,12 @@ def update_elements_from_filtered_api_call(cohort_filter: str):
|
|
| 309 |
return [
|
| 310 |
gr.update(value=f"{case_count} Cases"), # case counter
|
| 311 |
gr.update(choices=active_choices, value=active_choices), # actives
|
|
|
|
|
|
|
| 312 |
] + card_updates
|
| 313 |
|
| 314 |
|
| 315 |
-
def update_json_from_cards(*selected_filters_per_card):
|
| 316 |
ops = []
|
| 317 |
for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
|
| 318 |
# use the default values to determine card type (checkbox, range, etc)
|
|
@@ -368,7 +392,7 @@ def update_json_from_cards(*selected_filters_per_card):
|
|
| 368 |
return gr.update(value=filter_json)
|
| 369 |
|
| 370 |
|
| 371 |
-
def update_json_from_active(active_selections: list[str]):
|
| 372 |
grouped_selections = defaultdict(list)
|
| 373 |
for k_v in active_selections:
|
| 374 |
idx = k_v.find(": ")
|
|
@@ -431,11 +455,11 @@ def update_json_from_active(active_selections: list[str]):
|
|
| 431 |
return update_json_from_cards(*selected_filters_per_card)
|
| 432 |
|
| 433 |
|
| 434 |
-
def get_default_filter():
|
| 435 |
return json.dumps({"op": "and", "content": []}, indent=4)
|
| 436 |
|
| 437 |
|
| 438 |
-
def set_active_tab(selected_tab):
|
| 439 |
visibles = [gr.update(visible=(tab == selected_tab)) for tab in TAB_NAMES]
|
| 440 |
elem_classes = [
|
| 441 |
gr.update(variant="primary" if tab == selected_tab else "secondary")
|
|
@@ -444,6 +468,37 @@ def set_active_tab(selected_tab):
|
|
| 444 |
return visibles + elem_classes
|
| 445 |
|
| 446 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
DOWNLOAD_CASES_JS = f"""
|
| 448 |
function download_cases(filter_str) {{
|
| 449 |
const params = new URLSearchParams();
|
|
@@ -486,7 +541,7 @@ function download_cases(filter_str) {{
|
|
| 486 |
"""
|
| 487 |
|
| 488 |
with gr.Blocks(css_paths="style.css") as demo:
|
| 489 |
-
gr.Markdown("# GDC Cohort Copilot
|
| 490 |
|
| 491 |
with gr.Row(equal_height=True):
|
| 492 |
with gr.Column(scale=7):
|
|
@@ -517,21 +572,46 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 517 |
)
|
| 518 |
|
| 519 |
with gr.Row(equal_height=True):
|
| 520 |
-
with gr.Column(scale=
|
| 521 |
gr.Examples(
|
| 522 |
examples=EXAMPLE_INPUTS,
|
| 523 |
inputs=text_input,
|
| 524 |
)
|
| 525 |
-
with gr.Column(scale=
|
| 526 |
json_output = gr.Code(
|
| 527 |
label="Cohort Filter JSON",
|
| 528 |
-
# value=json.dumps({"op": "and", "content": []}, indent=4),
|
| 529 |
language="json",
|
| 530 |
interactive=False,
|
| 531 |
show_label=True,
|
| 532 |
container=True,
|
| 533 |
elem_id="json-output",
|
| 534 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
|
| 536 |
with gr.Row(equal_height=True):
|
| 537 |
with gr.Column(scale=1, min_width=250):
|
|
@@ -544,14 +624,6 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 544 |
elem_id="active-selections",
|
| 545 |
)
|
| 546 |
|
| 547 |
-
with gr.Row():
|
| 548 |
-
gr.Markdown(
|
| 549 |
-
"The generated cohort filter will autopopulate into the filter cards below. "
|
| 550 |
-
"**GDC Cohort Copilot can make mistakes!** "
|
| 551 |
-
"Refine your search using the interactive checkboxes. "
|
| 552 |
-
"Note that many other options can be found by selecting the different tabs on the left."
|
| 553 |
-
)
|
| 554 |
-
|
| 555 |
with gr.Row():
|
| 556 |
# Tab selectors
|
| 557 |
tab_buttons = []
|
|
@@ -613,7 +685,8 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 613 |
fn=set_active_tab,
|
| 614 |
inputs=gr.State(name),
|
| 615 |
outputs=tab_containers + tab_buttons,
|
| 616 |
-
api_name=False,
|
|
|
|
| 617 |
)
|
| 618 |
|
| 619 |
# Callback for case download button
|
|
@@ -621,11 +694,29 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 621 |
fn=None, # apparently this isn't the same as not specifying it, even though the default is None?
|
| 622 |
js=DOWNLOAD_CASES_JS, # need custom JSON to execute browser side download
|
| 623 |
inputs=json_output,
|
| 624 |
-
api_name=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
)
|
| 626 |
|
| 627 |
# Model generation should change the JSON filter
|
| 628 |
# All other element updates cascade
|
|
|
|
| 629 |
text_input.submit(
|
| 630 |
fn=generate_filter,
|
| 631 |
inputs=text_input,
|
|
@@ -640,16 +731,16 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 640 |
fn=update_json_from_cards,
|
| 641 |
inputs=filter_cards,
|
| 642 |
outputs=json_output,
|
| 643 |
-
# api_name=False,
|
| 644 |
-
show_api=False,
|
| 645 |
)
|
| 646 |
else:
|
| 647 |
filter_card.input(
|
| 648 |
fn=update_json_from_cards,
|
| 649 |
inputs=filter_cards,
|
| 650 |
outputs=json_output,
|
| 651 |
-
# api_name=False,
|
| 652 |
-
show_api=False,
|
| 653 |
)
|
| 654 |
|
| 655 |
# Changing the active selections should change the JSON filter
|
|
@@ -658,17 +749,17 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 658 |
fn=update_json_from_active,
|
| 659 |
inputs=active_selections,
|
| 660 |
outputs=json_output,
|
| 661 |
-
# api_name=False,
|
| 662 |
-
show_api=False,
|
| 663 |
)
|
| 664 |
|
| 665 |
# JSON filter change executes API call and updates all elements
|
| 666 |
json_output.change(
|
| 667 |
fn=update_elements_from_filtered_api_call,
|
| 668 |
inputs=json_output,
|
| 669 |
-
outputs=[case_counter, active_selections] + filter_cards,
|
| 670 |
-
# api_name=False,
|
| 671 |
-
show_api=False,
|
| 672 |
)
|
| 673 |
|
| 674 |
# Trigger initial update
|
|
@@ -678,6 +769,7 @@ with gr.Blocks(css_paths="style.css") as demo:
|
|
| 678 |
outputs=json_output,
|
| 679 |
# api_name=False, # this breaks the API functionality, not sure why
|
| 680 |
show_api=False, # so just hide the API endpoints instead, not ideal
|
|
|
|
| 681 |
)
|
| 682 |
|
| 683 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
from collections import defaultdict
|
| 4 |
+
from datetime import datetime, timezone
|
| 5 |
+
from pathlib import Path
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import requests
|
|
|
|
| 15 |
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed
|
| 16 |
|
| 17 |
from schema import GDCCohortSchema # isort: skip
|
| 18 |
+
from scheduler import ParquetScheduler # isort: skip
|
| 19 |
|
| 20 |
EXAMPLE_INPUTS = [
|
| 21 |
"bam files for TCGA-BRCA",
|
|
|
|
| 26 |
GDC_CASES_API_ENDPOINT = "https://api.gdc.cancer.gov/cases"
|
| 27 |
MODEL_NAME = "uc-ctds/gdc-cohort-llm-gpt2-s1M"
|
| 28 |
TOKENIZER_NAME = MODEL_NAME
|
| 29 |
+
MODEL_READ_TOKEN = os.environ.get("MODEL_READ_TOKEN", None)
|
| 30 |
+
DATASET_WRITE_TOKEN = os.environ.get("DATASET_WRITE_TOKEN", None)
|
| 31 |
|
| 32 |
with open("config.yaml", "r") as f:
|
| 33 |
CONFIG = yaml.safe_load(f)
|
|
|
|
| 49 |
]
|
| 50 |
)
|
| 51 |
|
| 52 |
+
PREF_DS = os.environ.get("PREF_DS", False)
|
| 53 |
+
if PREF_DS:
|
| 54 |
+
assert DATASET_WRITE_TOKEN is not None
|
| 55 |
+
scheduler = ParquetScheduler(
|
| 56 |
+
repo_id=PREF_DS,
|
| 57 |
+
token=DATASET_WRITE_TOKEN,
|
| 58 |
+
schema={
|
| 59 |
+
"prompt": {"_type": "Value", "dtype": "string"},
|
| 60 |
+
"cohort_filter": {"_type": "Value", "dtype": "string"},
|
| 61 |
+
"preference": {"_type": "Value", "dtype": "bool"},
|
| 62 |
+
"timestamp": {"_type": "Value", "dtype": "string"},
|
| 63 |
+
},
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
|
| 67 |
+
tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, token=MODEL_READ_TOKEN)
|
| 68 |
+
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME, token=MODEL_READ_TOKEN)
|
| 69 |
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 70 |
model = model.eval()
|
| 71 |
|
|
|
|
| 99 |
return cohort_filter
|
| 100 |
|
| 101 |
|
| 102 |
+
def _prepare_value_count(value: str, count: int) -> str:
|
| 103 |
return f"{value} [{count}]"
|
| 104 |
|
| 105 |
|
| 106 |
+
def _get_base_value(value_count: str) -> str:
|
| 107 |
value = value_count
|
| 108 |
if " [" in value:
|
| 109 |
value = value[: value.rfind(" [")]
|
|
|
|
| 201 |
active_choices.append(f"{card_name.upper()}: {value}")
|
| 202 |
elif isinstance(default_values, dict):
|
| 203 |
# range-slider, maybe other options in the future?
|
| 204 |
+
assert default_values["type"] == "range", f"Expected range slider for card {card_name}" # fmt: skip
|
| 205 |
assert isinstance(values, int), "values should be integer for range op"
|
| 206 |
if ">=" in field:
|
| 207 |
if values != default_values["min"]:
|
|
|
|
| 299 |
return card_updates
|
| 300 |
|
| 301 |
|
| 302 |
+
def update_elements_from_filtered_api_call(cohort_filter: str) -> list[dict]:
|
| 303 |
# return updates for:
|
| 304 |
+
# - counter (text)
|
| 305 |
+
# - active selections (checkbox group)
|
| 306 |
+
# - upvote (enable button, reset text)
|
| 307 |
+
# - downvote (enable button, reset text)
|
| 308 |
+
# - cards (list of checkbox group)
|
| 309 |
|
| 310 |
# --- Execute API Call ---
|
| 311 |
patched_cohort_filter = _patch_range_filters_for_facet_endpoint(cohort_filter)
|
|
|
|
| 331 |
return [
|
| 332 |
gr.update(value=f"{case_count} Cases"), # case counter
|
| 333 |
gr.update(choices=active_choices, value=active_choices), # actives
|
| 334 |
+
gr.update(interactive=True, value="⬆"),
|
| 335 |
+
gr.update(interactive=True, value="⬇"),
|
| 336 |
] + card_updates
|
| 337 |
|
| 338 |
|
| 339 |
+
def update_json_from_cards(*selected_filters_per_card: tuple[str]) -> str:
|
| 340 |
ops = []
|
| 341 |
for card_name, selected_filters in zip(CARD_NAMES, selected_filters_per_card):
|
| 342 |
# use the default values to determine card type (checkbox, range, etc)
|
|
|
|
| 392 |
return gr.update(value=filter_json)
|
| 393 |
|
| 394 |
|
| 395 |
+
def update_json_from_active(active_selections: list[str]) -> str:
|
| 396 |
grouped_selections = defaultdict(list)
|
| 397 |
for k_v in active_selections:
|
| 398 |
idx = k_v.find(": ")
|
|
|
|
| 455 |
return update_json_from_cards(*selected_filters_per_card)
|
| 456 |
|
| 457 |
|
| 458 |
+
def get_default_filter() -> str:
    """Return the empty cohort filter (an AND of no clauses) as pretty-printed JSON."""
    empty_filter = {"op": "and", "content": []}
    return json.dumps(empty_filter, indent=4)
|
| 460 |
|
| 461 |
|
| 462 |
+
def set_active_tab(selected_tab: str) -> list[dict]:
|
| 463 |
visibles = [gr.update(visible=(tab == selected_tab)) for tab in TAB_NAMES]
|
| 464 |
elem_classes = [
|
| 465 |
gr.update(variant="primary" if tab == selected_tab else "secondary")
|
|
|
|
| 468 |
return visibles + elem_classes
|
| 469 |
|
| 470 |
|
| 471 |
+
def save_user_preference(cohort_query: str, cohort_filter: str, preference: bool) -> list[dict]:  # fmt: skip
    """Persist a single up/down vote and return updates disabling both vote buttons.

    The record (prompt, compacted filter JSON, vote, UTC timestamp) is appended to
    the parquet scheduler when PREF_DS is configured; otherwise it is only printed.
    Returns gradio updates for the upvote and downvote buttons (in that order).
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    record = {
        "prompt": cohort_query,
        # round-trip through json to strip pretty-printing whitespace
        "cohort_filter": json.dumps(json.loads(cohort_filter)),
        "preference": preference,
        "timestamp": timestamp,
    }
    if PREF_DS:
        scheduler.append(record)
        print(f"Logged user preference data at {timestamp}")
    else:
        print(
            "No preference dataset configured, "
            "set PREF_DS env var to point to a HuggingFace Dataset Repo. "
            f"Would have logged {record}"
        )

    # Disable both buttons and mark which one was clicked; "--" is used for the
    # unclicked one because plain whitespace seems to be escaped by gradio.
    up_label, down_label = ("✓", "--") if preference else ("--", "✗")
    return [
        gr.update(interactive=False, value=up_label),
        gr.update(interactive=False, value=down_label),
    ]
|
| 500 |
+
|
| 501 |
+
|
| 502 |
DOWNLOAD_CASES_JS = f"""
|
| 503 |
function download_cases(filter_str) {{
|
| 504 |
const params = new URLSearchParams();
|
|
|
|
| 541 |
"""
|
| 542 |
|
| 543 |
with gr.Blocks(css_paths="style.css") as demo:
|
| 544 |
+
gr.Markdown("# GDC Cohort Copilot")
|
| 545 |
|
| 546 |
with gr.Row(equal_height=True):
|
| 547 |
with gr.Column(scale=7):
|
|
|
|
| 572 |
)
|
| 573 |
|
| 574 |
with gr.Row(equal_height=True):
|
| 575 |
+
with gr.Column(scale=2, min_width=250):
|
| 576 |
gr.Examples(
|
| 577 |
examples=EXAMPLE_INPUTS,
|
| 578 |
inputs=text_input,
|
| 579 |
)
|
| 580 |
+
with gr.Column(scale=7):
|
| 581 |
json_output = gr.Code(
|
| 582 |
label="Cohort Filter JSON",
|
|
|
|
| 583 |
language="json",
|
| 584 |
interactive=False,
|
| 585 |
show_label=True,
|
| 586 |
container=True,
|
| 587 |
elem_id="json-output",
|
| 588 |
)
|
| 589 |
+
with gr.Column(scale=1, min_width=50):
|
| 590 |
+
gr.Markdown(
|
| 591 |
+
"Is this correct?",
|
| 592 |
+
elem_id="vote-label",
|
| 593 |
+
)
|
| 594 |
+
upvote = gr.Button(
|
| 595 |
+
value="⬆",
|
| 596 |
+
min_width=50,
|
| 597 |
+
elem_id="upvote-btn",
|
| 598 |
+
)
|
| 599 |
+
downvote = gr.Button(
|
| 600 |
+
value="⬇",
|
| 601 |
+
min_width=50,
|
| 602 |
+
elem_id="download-btn",
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
with gr.Row():
|
| 606 |
+
gr.Markdown(
|
| 607 |
+
"The generated cohort filter will autopopulate into the filter cards below. "
|
| 608 |
+
"**<u>GDC Cohort Copilot can make mistakes!</u>** "
|
| 609 |
+
"Refine your search using the interactive checkboxes. "
|
| 610 |
+
"Note that many other options can be found by selecting the different tabs. "
|
| 611 |
+
"**<u>If you'd like to help us improve our model</u>**, you can use the up or down vote button to send us feedback. "
|
| 612 |
+
"We'll only save the current free text description, the cohort filter JSON, and your vote. "
|
| 613 |
+
"You can also show us what the right filter should have been by manually refining it using the checkboxes, before up voting."
|
| 614 |
+
)
|
| 615 |
|
| 616 |
with gr.Row(equal_height=True):
|
| 617 |
with gr.Column(scale=1, min_width=250):
|
|
|
|
| 624 |
elem_id="active-selections",
|
| 625 |
)
|
| 626 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 627 |
with gr.Row():
|
| 628 |
# Tab selectors
|
| 629 |
tab_buttons = []
|
|
|
|
| 685 |
fn=set_active_tab,
|
| 686 |
inputs=gr.State(name),
|
| 687 |
outputs=tab_containers + tab_buttons,
|
| 688 |
+
# api_name=False,
|
| 689 |
+
show_api=False,
|
| 690 |
)
|
| 691 |
|
| 692 |
# Callback for case download button
|
|
|
|
| 694 |
fn=None, # apparently this isn't the same as not specifying it, even though the default is None?
|
| 695 |
js=DOWNLOAD_CASES_JS, # need custom JSON to execute browser side download
|
| 696 |
inputs=json_output,
|
| 697 |
+
# api_name=False,
|
| 698 |
+
show_api=False,
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
# Enable user preference logging
|
| 702 |
+
upvote.click(
|
| 703 |
+
fn=save_user_preference,
|
| 704 |
+
inputs=[text_input, json_output, gr.State(True)],
|
| 705 |
+
outputs=[upvote, downvote],
|
| 706 |
+
# api_name=False,
|
| 707 |
+
show_api=False,
|
| 708 |
+
)
|
| 709 |
+
downvote.click(
|
| 710 |
+
fn=save_user_preference,
|
| 711 |
+
inputs=[text_input, json_output, gr.State(False)],
|
| 712 |
+
outputs=[upvote, downvote],
|
| 713 |
+
# api_name=False,
|
| 714 |
+
show_api=False,
|
| 715 |
)
|
| 716 |
|
| 717 |
# Model generation should change the JSON filter
|
| 718 |
# All other element updates cascade
|
| 719 |
+
# This is the only API that should be exposed
|
| 720 |
text_input.submit(
|
| 721 |
fn=generate_filter,
|
| 722 |
inputs=text_input,
|
|
|
|
| 731 |
fn=update_json_from_cards,
|
| 732 |
inputs=filter_cards,
|
| 733 |
outputs=json_output,
|
| 734 |
+
# api_name=False,
|
| 735 |
+
show_api=False,
|
| 736 |
)
|
| 737 |
else:
|
| 738 |
filter_card.input(
|
| 739 |
fn=update_json_from_cards,
|
| 740 |
inputs=filter_cards,
|
| 741 |
outputs=json_output,
|
| 742 |
+
# api_name=False,
|
| 743 |
+
show_api=False,
|
| 744 |
)
|
| 745 |
|
| 746 |
# Changing the active selections should change the JSON filter
|
|
|
|
| 749 |
fn=update_json_from_active,
|
| 750 |
inputs=active_selections,
|
| 751 |
outputs=json_output,
|
| 752 |
+
# api_name=False,
|
| 753 |
+
show_api=False,
|
| 754 |
)
|
| 755 |
|
| 756 |
# JSON filter change executes API call and updates all elements
|
| 757 |
json_output.change(
|
| 758 |
fn=update_elements_from_filtered_api_call,
|
| 759 |
inputs=json_output,
|
| 760 |
+
outputs=[case_counter, active_selections, upvote, downvote] + filter_cards,
|
| 761 |
+
# api_name=False,
|
| 762 |
+
show_api=False,
|
| 763 |
)
|
| 764 |
|
| 765 |
# Trigger initial update
|
|
|
|
| 769 |
outputs=json_output,
|
| 770 |
# api_name=False, # this breaks the API functionality, not sure why
|
| 771 |
show_api=False, # so just hide the API endpoints instead, not ideal
|
| 772 |
+
# the weirdness with the API toggle seems true for all disabled API endpoints
|
| 773 |
)
|
| 774 |
|
| 775 |
if __name__ == "__main__":
|
scheduler.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Taken from https://huggingface.co/spaces/Wauplin/space_to_dataset_saver
|
| 2 |
+
# which was from https://huggingface.co/spaces/hysts-samples/save-user-preferences
|
| 3 |
+
# Credits to @@hysts and @@Wauplin
|
| 4 |
+
import json
|
| 5 |
+
import tempfile
|
| 6 |
+
import uuid
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, List, Optional, Union
|
| 9 |
+
|
| 10 |
+
import pyarrow as pa
|
| 11 |
+
import pyarrow.parquet as pq
|
| 12 |
+
from huggingface_hub import CommitScheduler
|
| 13 |
+
from huggingface_hub.hf_api import HfApi
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ParquetScheduler(CommitScheduler):
    """Commit scheduler that buffers rows and uploads them to the Hub as parquet files.

    Taken from https://huggingface.co/spaces/Wauplin/space_to_dataset_saver,
    which was from https://huggingface.co/spaces/hysts-samples/save-user-preferences.
    Credits to @hysts and @Wauplin.

    Usage: configure the scheduler with a repo id.
    Once started, you can add data to be uploaded to the Hub.
    Each `.append` call will result in a new row in your final dataset.
    The scheduler requires you manually set the schema (read [the docs](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value) for the list of possible values):

    ```py
    # Start scheduler
    >>> scheduler = ParquetScheduler(
    ...     repo_id="my-org/my-dataset",
    ...     schema={
    ...         "prompt": {"_type": "Value", "dtype": "string"},
    ...         "cohort_filter": {"_type": "Value", "dtype": "string"},
    ...         "preference": {"_type": "Value", "dtype": "bool"},
    ...         "timestamp": {"_type": "Value", "dtype": "string"},
    ...     },
    ... )

    # Append some data to be uploaded
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    ```
    """

    def __init__(
        self,
        *,
        repo_id: str,
        schema: Optional[Dict[str, Dict[str, str]]] = None,
        every: Union[int, float] = 5,
        path_in_repo: Optional[str] = "data",
        repo_type: Optional[str] = "dataset",
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Union[List[str], str, None] = None,
        ignore_patterns: Union[List[str], str, None] = None,
        hf_api: Optional[HfApi] = None,
    ) -> None:
        super().__init__(
            repo_id=repo_id,
            folder_path="dummy",  # not used by the scheduler
            every=every,
            path_in_repo=path_in_repo,
            repo_type=repo_type,
            revision=revision,
            private=private,
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            hf_api=hf_api,
        )

        # Rows buffered between two scheduled commits; guarded by `self.lock`.
        self._rows: List[Dict[str, Any]] = []
        self._schema = schema

    def append(self, row: Dict[str, Any]) -> None:
        """Add a new item to be uploaded."""
        with self.lock:
            self._rows.append(row)

    def push_to_hub(self):
        """Flush all buffered rows to the Hub as a single new parquet file.

        Called periodically by the CommitScheduler background thread; may also
        be invoked manually to flush pending rows immediately. No-op when no
        rows are pending.
        """
        # Atomically take ownership of the pending rows so `append` can keep
        # running while we upload.
        with self.lock:
            rows = self._rows
            self._rows = []
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Load images + create 'features' config for datasets library
        schema: Dict[str, Dict] = self._schema or {}
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Load binary files if necessary. Use .get() so a row key that
                # has no schema entry (or a missing schema) does not raise a
                # KeyError here.
                if schema.get(key, {}).get("_type") in ("Image", "Audio"):
                    # It's an image or audio: we load the bytes and remember to cleanup the file
                    file_path = Path(value)
                    if file_path.is_file():
                        row[key] = {
                            "path": file_path.name,
                            "bytes": file_path.read_bytes(),
                        }
                        path_to_cleanup.append(file_path)

        # Complete rows if needed
        for row in rows:
            for feature in schema:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by datasets library)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": schema}})}
        )

        # Write to parquet file
        archive_file = tempfile.NamedTemporaryFile()
        pq.write_table(table, archive_file.name)

        # Upload under the configured folder (`path_in_repo`, default "data")
        # rather than the repo root, so files land where the scheduler was told
        # to put them.
        filename = f"{uuid.uuid4()}.parquet"
        self.api.upload_file(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
            path_in_repo=f"{self.path_in_repo}/{filename}" if self.path_in_repo else filename,
            path_or_fileobj=archive_file.name,
        )
        print("Commit completed.")

        # Cleanup
        archive_file.close()
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)
|
style.css
CHANGED
|
@@ -19,6 +19,14 @@
|
|
| 19 |
font-size: calc(var(--block-title-text-size) + 2px);
|
| 20 |
}
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
#json-output {
|
| 23 |
height: 96px !important;
|
| 24 |
}
|
|
|
|
| 19 |
font-size: calc(var(--block-title-text-size) + 2px);
|
| 20 |
}
|
| 21 |
|
| 22 |
+
#vote-label {
|
| 23 |
+
text-align: center;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
#upvote-btn {
|
| 27 |
+
color: var(--button-primary-background-fill);
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
#json-output {
|
| 31 |
height: 96px !important;
|
| 32 |
}
|