Spaces:
Sleeping
Sleeping
File size: 29,117 Bytes
7db7222 9905bb9 fb56316 8dddd0a 7db7222 64da2e4 6894046 64da2e4 a91af1f 5369f18 13ce2ed 0194537 85f5622 8dddd0a 7db7222 6894046 8dddd0a 22e0374 e4ea637 8dddd0a 7db7222 4e30f9c 5c7c416 1b1f29e 589589c e365b79 36b6975 1b1f29e 36b6975 1b1f29e cb6810e 8dddd0a 22e0374 7db7222 22e0374 7db7222 22e0374 7db7222 22e0374 6894046 6c4a12e 47785b8 22e0374 55b9f90 22e0374 7db7222 22e0374 7db7222 22e0374 7db7222 22e0374 64da2e4 cfcda68 5c697cd cfcda68 5197cc3 47785b8 6894046 47785b8 a79977a 6894046 69f52c7 062ea00 69f52c7 de6ca97 69f52c7 3067368 279b4ba d010c95 6894046 3067368 6894046 69f52c7 5197cc3 968854a 69f52c7 5197cc3 69f52c7 5197cc3 6894046 3067368 6894046 5197cc3 69f52c7 5197cc3 cb6810e 7db7222 8dddd0a cb6810e ef893cf c1efbdb 7db7222 c1efbdb 6c4a12e 744db0f 988342a cb6810e 8dddd0a 55b9f90 45bb144 8dddd0a de6ca97 8dddd0a 988342a de6ca97 8dddd0a b374f9a 8dddd0a 5197cc3 8dddd0a 5197cc3 8dddd0a 5197cc3 8dddd0a 45bb144 8dddd0a 0c1f098 8dddd0a c1efbdb 9e9b506 64da2e4 45bb144 8dddd0a 7db7222 8dddd0a fb56316 8dddd0a a7bbe08 697de72 fb56316 7db7222 fb56316 5c697cd fb56316 a7bbe08 697de72 e4ea637 5c697cd fb56316 5c697cd fb56316 8dddd0a cb6810e 55b9f90 cb6810e 9e9b506 fb56316 81d7e97 8dddd0a 55b9f90 8dddd0a fb56316 8dddd0a fb56316 8dddd0a fb56316 8dddd0a fb56316 8dddd0a cb6810e 55b9f90 8dddd0a fb56316 8dddd0a fb56316 6894046 cb6810e fb56316 6894046 fb56316 eb16a43 fb56316 6894046 fb56316 8dddd0a cb6810e 8e21388 2432539 cb6810e 2432539 cb6810e 2432539 ff1765c 2432539 ff1765c 2432539 8e21388 2432539 333dd60 2432539 d9e6326 9bcdbf3 d9e6326 9bcdbf3 d9e6326 85f5622 d9e6326 ebfb349 f482c76 75c7aee 9bcdbf3 566141e 9bcdbf3 c6a7d09 11efff9 566141e 9bcdbf3 ebfb349 e077b87 c6a7d09 67d919c dfa4cfb a399e08 314e0d3 67d919c 8e37fce 49e8e69 f482c76 67d919c f482c76 cb6810e 8dddd0a 11efff9 7db7222 00d709d 8dddd0a b8bd586 6894046 7806d0b 6894046 7806d0b 6894046 8f3a0af 7806d0b 6894046 b8bd586 fb56316 b8bd586 fb56316 b8bd586 a828103 b8bd586 45bb144 b8bd586 29bf05f 67d919c dfa4cfb 67d919c 29bf05f b8bd586 
d9e6326 edc6ff0 11efff9 f5e7c06 29bf05f 8dddd0a 64da2e4 b8bd586 7b1bb43 b8bd586 7b1bb43 b8bd586 7b1bb43 b8bd586 64da2e4 f07003c 64da2e4 7db7222 333dd60 64da2e4 333dd60 b8bd586 64da2e4 1b1f29e 7db7222 1b1f29e 29bf05f 64da2e4 65aa518 64da2e4 11efff9 29bf05f 8dddd0a 64da2e4 f5e7c06 64da2e4 8dddd0a 64da2e4 7db7222 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 |
import json
import os
import re
from types import SimpleNamespace
import gradio as gr
from itertools import chain
import pandas as pd
import numpy as np
import spaces
import spacy
import torch
import textwrap
from guidance import gen as guidance_gen
from guidance.models import Transformers
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BertForSequenceClassification,
BertTokenizer,
set_seed,
)
from sentence_transformers import SentenceTransformer, util
from methods import gdc_api_calls, utilities
# set up various tokens
# Access token read from the "hf_svc_ctds" environment variable; falls back to
# False when the variable is unset. It is passed as `token=` to the
# from_pretrained calls and to utilities.load_gdc_genes_mutations_hf below.
hf_TOKEN = os.environ.get("hf_svc_ctds", False)
# disable tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Sample questions offered in the UI; passed as `examples=` to gr.Examples.
# Entries pair positionally with EXAMPLE_LABELS, so keep both lists the same
# length and in the same order.
EXAMPLE_INPUTS = [
"What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?",
"What is the co-occurence frequency of somatic heterozygous deletions in BRCA2 and NF1 in the Kidney Chromophobe TCGA-KICH project in the genomic data commons?",
"What percentage of ovarian serous cystadenocarcinoma cases have a somatic heterozygous deletion in BRCA1 and simple somatic mutations in BRCA1 in the genomic data commons?",
"What fraction of cases have simple somatic mutations or copy number variants in ALK in Lung Adenocarcinoma TCGA-LUAD project in the genomic data commons?",
"How often is microsatellite instability observed in Colon Adenocarcinoma TCGA-COAD project in the genomic data commons?",
"How often is the BRAF V600E mutation found in Skin Cutaneous Melanoma TCGA-SKCM project in the genomic data commons?",
"What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?",
"In Lung Adenocarcinoma TCGA-LUAD project data from the genomic data commons, what is the frequency of ALK amplification?"
]
# Short display labels for the sample questions; passed as `example_labels=`
# to gr.Examples alongside EXAMPLE_INPUTS (one label per example, same order).
EXAMPLE_LABELS = [
"combination homozygous deletions",
"combination heterozygous deletions",
"heterozygous deletion and somatic mutations",
"copy number variants or somatic mutations",
"microsatellite-instability",
"simple somatic mutation",
"combination somatic mutations",
"single gene amplification"
]
# for natural language gene and intent descriptions
# Maps an intent name to the phrase inserted into the descriptive LLM prompt
# (execute_pipeline looks the inferred intent up here). An intent that is not
# a key of this dict raises KeyError at that lookup.
# NOTE(review): 'freq_cnv_loss_or_gain_comb' is not among the labels that
# infer_user_intent can return in this file - confirm whether it is still used.
intent_expansion = {
'cnv_and_ssm': 'copy number variants or simple somatic mutations',
'freq_cnv_loss_or_gain': 'copy number variant losses or gains',
'msi_h_frequency': 'microsatellite instability',
'freq_cnv_loss_or_gain_comb': 'copy number variant losses or gains',
'ssm_frequency': 'simple somatic mutations',
'top_cases_counts_by_gene': 'copy number variants or simple somatic mutations'
}
# set up requirements: models and data
# Everything below runs at import time and moves three models to "cuda".
print("getting gdc project information")
# Used throughout this file as a mapping of project id -> iterable of
# description strings (see the .items() loops in the matching helpers).
project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)
print("loading intent model and tokenizer")
# BERT sequence classifier used by infer_user_intent.
model_id = "uc-ctds/query_intent"
intent_tok = AutoTokenizer.from_pretrained(
model_id, trust_remote_code=True, token=hf_TOKEN
)
intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
intent_model = intent_model.to("cuda").eval()
# load sentence transformer model to test cancer embeddings
print('loading sentence transformer model')
st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
st_model = st_model.to("cuda").eval()
print("loading gdc genes and mutations")
# Used in this file as a mapping of gene -> iterable of mutation strings.
gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
print("loading llama-3B model and tokenizer")
# NOTE: model_id is rebound here; from this point on it names the Llama model,
# not the intent classifier.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, trust_remote_code=True, token=hf_TOKEN
)
model = model.to("cuda").eval()
# global init to test guidance speed up
# Shared guidance wrapper; the generate_* functions start each prompt from it.
base_lm = Transformers(model=model, tokenizer=tok)
@utilities.timeit
def infer_mutation_entities(gene_entities, query):
    """Collect known mutations of the given genes that occur in the query.

    For each gene in ``gene_entities`` the mutations listed under that gene in
    the module-level ``gdc_genes_mutations`` mapping are tested with a plain
    substring check against ``query``.

    Args:
        gene_entities: genes to look up; each must be a key of
            ``gdc_genes_mutations`` (a missing key raises ``KeyError``).
        query: the natural language question.

    Returns:
        list of matching mutation strings, ordered by gene and then by the
        order of that gene's mutations. A mutation listed under several of
        the genes appears once per gene.
    """
    return [
        mutation
        for gene in gene_entities
        for mutation in gdc_genes_mutations[gene]
        if mutation in query
    ]
@utilities.timeit
def infer_gene_entities_from_query(query):
    """Find gene symbols mentioned as whole words in the query.

    Gene recognition is a simple dictionary lookup: the query is split on
    whitespace and every key of the module-level ``gdc_genes_mutations``
    mapping that equals one of the tokens is reported.

    Compared with a bare ``query.split(" ")`` membership test this also
    recognises a gene that is directly followed or preceded by punctuation
    (e.g. ``"TP53?"`` or ``"(BRCA1)"``) or separated by tabs/newlines. Every
    token that matched before still matches.

    Args:
        query: the natural language question.

    Returns:
        list of gene symbols, in the iteration order of ``gdc_genes_mutations``.
    """
    # punctuation that may be glued to a gene symbol in free text
    surrounding_punctuation = ".,;:!?()[]{}\"'"

    # tokenise once instead of once per gene; keep both the raw and the
    # punctuation-stripped form so nothing that matched before is lost
    raw_tokens = query.split()
    tokens = set(raw_tokens)
    tokens.update(token.strip(surrounding_punctuation) for token in raw_tokens)
    tokens.discard("")

    return [gene for gene in gdc_genes_mutations if gene in tokens]
@spaces.GPU(duration=15)
def get_project_embeddings():
    """Embed one text row per GDC project with the sentence transformer.

    Each row is the project id followed by that project's description strings
    (commas removed from the descriptions), joined with ``','``.

    Returns:
        tuple ``(project_rows, embeddings)`` where ``project_rows`` is the list
        of row strings and ``embeddings`` is the encoder output moved to the
        CPU and converted to a NumPy array.
    """
    project_rows = [
        ','.join([project_id] + [text.replace(',', '') for text in descriptions])
        for project_id, descriptions in project_mappings.items()
    ]
    embeddings = st_model.encode(
        project_rows, convert_to_tensor=True, device='cuda'
    )
    return project_rows, embeddings.cpu().numpy()
def check_if_project_id_in_query(query):
    """Return the query tokens that are GDC project ids.

    The query is split on single spaces and each token is compared exactly
    against the keys of ``project_mappings`` (e.g. ``TCGA-BRCA``).

    Args:
        query: the natural language question.

    Returns:
        list of matching tokens in query order (repeats are kept).
    """
    known_project_ids = project_mappings.keys()
    matches = []
    for token in query.split(" "):
        if token in known_project_ids:
            matches.append(token)
    return matches
def proj_id_and_partial_match(query, initial_cancer_entities):
    """Match cancer entities (or query terms) against GDC project descriptions.

    A project is selected when any search term is a substring of any of its
    description strings in ``project_mappings``, compared case-insensitively,
    e.g. "ovarian serous cystadenocarcinoma" selects TCGA-OV.

    The search terms are the ``initial_cancer_entities`` when there are any,
    otherwise the whitespace-separated words of ``query``.

    Fixes over the previous version:
      * entities are lowercased before comparison (descriptions already were,
        so an entity containing capitals could never match);
      * empty terms are ignored - the empty string is a substring of every
        description, so a double space in the query selected every project.

    NOTE: when falling back to query words, short common words are still
    matched as substrings and can select many projects.

    Args:
        query: the natural language question.
        initial_cancer_entities: candidate cancer names; may be empty.

    Returns:
        list of unique project ids, in ``project_mappings`` order.
    """
    if initial_cancer_entities:
        raw_terms = [entity.lower() for entity in initial_cancer_entities]
    else:
        raw_terms = query.lower().split()
    terms = [term.strip() for term in raw_terms]
    terms = [term for term in terms if term]

    matched_projects = []
    for project_id, descriptions in project_mappings.items():
        lowered = [description.lower() for description in descriptions]
        if any(term in text for term in terms for text in lowered):
            matched_projects.append(project_id)
    return matched_projects
@spaces.GPU(duration=15)
def get_top_k_scores(query, row_embeddings, top_k=20):
    """Rank project rows by cosine similarity to the query.

    Args:
        query: text to embed with the sentence transformer.
        row_embeddings: NumPy array of row embeddings (as returned by
            ``get_project_embeddings``).
        top_k: maximum number of rows to return. Clamped to the number of
            rows available, because ``torch.topk`` raises when ``k`` exceeds
            the size of the dimension.

    Returns:
        tuple ``(scores, indices)`` of nested Python lists taken from the
        ``torch.topk`` result over the similarity matrix.
    """
    query_embedding = st_model.encode(query, convert_to_tensor=True, device='cuda')
    row_tensor = torch.from_numpy(row_embeddings).float().to('cuda')
    cosine_scores = util.cos_sim(query_embedding, row_tensor)
    k = min(top_k, cosine_scores.shape[-1])
    top_results = torch.topk(cosine_scores, k=k)
    # convert to CPU and return
    top_results_scores = top_results.values.cpu().tolist()
    top_results_indices = top_results.indices.cpu().tolist()
    return top_results_scores, top_results_indices
def get_top_k_cancer_entities(
    project_rows, top_results_scores, top_results_indices, score_threshold=0.5
):
    """Turn top-k similarity results into a list of project ids.

    Only the first entry of ``top_results_scores`` / ``top_results_indices``
    is read (one query). A hit is kept when its score is strictly greater
    than ``score_threshold``; the project id is the text before the first
    comma of the corresponding row.

    Args:
        project_rows: row strings of the form ``"<project id>,<desc>,..."``.
        top_results_scores: nested list of scores, ``[[s0, s1, ...]]``.
        top_results_indices: nested list of row indices aligned with the
            scores.
        score_threshold: minimum (exclusive) similarity for a hit; defaults
            to the previously hard-coded 0.5.

    Returns:
        list of project ids in ranking order.
    """
    top_projects = []
    for row_idx, score in zip(top_results_indices[0], top_results_scores[0]):
        if score > score_threshold:
            row = project_rows[row_idx]
            print('best row, score: {} {}'.format(row, score))
            top_projects.append(row.split(',')[0])
    return top_projects
@utilities.timeit
def postprocess_cancer_entities(initial_cancer_entities, query):
    """Resolve the cancer/project mentions of a query to GDC project ids.

    Strategies are tried in order and the first non-empty result wins:
      1. project ids written literally in the query;
      2. (with initial entities) the GDC projects endpoint via
         ``gdc_api_calls.map_cancer_entities_to_project``;
      3. substring matching against project descriptions;
      4. (with initial entities) embedding similarity between each entity and
         the project rows.
    Without initial entities only steps 1 and 3 apply.

    The project embeddings are computed only when step 4 is reached. They
    used to be computed up front on every call, which spent a GPU call even
    when the project id was given in the query.

    Args:
        initial_cancer_entities: candidate cancer names; may be empty.
        query: the natural language question.

    Returns:
        list of project ids; empty when nothing matched.
    """
    final_entities = check_if_project_id_in_query(query)
    if final_entities:
        return final_entities

    if not initial_cancer_entities:
        # no initial_cancer_entities
        # check project_mappings keys/values for matches with query terms
        return proj_id_and_partial_match(query, initial_cancer_entities)

    # first query GDC projects endpt
    gdc_project_match = gdc_api_calls.map_cancer_entities_to_project(
        initial_cancer_entities, project_mappings
    )
    final_entities = list(gdc_project_match.values())

    if not final_entities:
        # no result from GDC projects endpt, check for matches between the
        # entities and gdc project_mappings
        final_entities = proj_id_and_partial_match(query, initial_cancer_entities)

    # try embedding based match
    if not final_entities:
        print('loading cancer embeddings')
        project_rows, row_embeddings = get_project_embeddings()
        print('Test embedding based match')
        matches_per_entity = []
        for entity in initial_cancer_entities:
            top_results_scores, top_results_indices = get_top_k_scores(
                entity, row_embeddings
            )
            matches_per_entity.append(
                get_top_k_cancer_entities(
                    project_rows, top_results_scores, top_results_indices
                )
            )
        final_entities = list(chain.from_iterable(matches_per_entity))

    return final_entities
@utilities.timeit
def execute_api_call(intent, gene_entities, mutation_entities, cancer_entities, query):
    """Run the GDC helper that corresponds to the inferred intent.

    Args:
        intent: intent name selecting the helper.
        gene_entities: genes found in the query.
        mutation_entities: mutations found in the query.
        cancer_entities: project ids the query refers to.
        query: the natural language question.

    Returns:
        tuple ``(result, cancer_entities)``. Helpers in the first table below
        return their own ``cancer_entities``; for the remaining intents the
        input list is passed through unchanged. An unknown intent yields an
        explanatory string as ``result``.
    """
    # helpers that return (result, cancer_entities)
    handlers_updating_entities = {
        "ssm_frequency": lambda: utilities.get_ssm_frequency(
            gene_entities, mutation_entities, cancer_entities, project_mappings
        ),
        "freq_cnv_loss_or_gain": lambda: gdc_api_calls.get_freq_cnv_loss_or_gain(
            gene_entities, cancer_entities, query, cnv_and_ssm_flag=False
        ),
        "msi_h_frequency": lambda: gdc_api_calls.get_msi_frequency(cancer_entities),
        "cnv_and_ssm": lambda: utilities.get_freq_of_cnv_and_ssms(
            query, cancer_entities, gene_entities, gdc_genes_mutations
        ),
        "top_cases_counts_by_gene": lambda: gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        ),
    }
    # helpers that return only the result
    handlers_result_only = {
        "top_mutated_genes_by_project": lambda: gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=10
        ),
        "most_frequently_mutated_gene": lambda: gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=1
        ),
        "project_summary": lambda: gdc_api_calls.get_project_summary(cancer_entities),
    }

    if intent in handlers_updating_entities:
        result, cancer_entities = handlers_updating_entities[intent]()
    elif intent in handlers_result_only:
        result = handlers_result_only[intent]()
    else:
        result = "user intent not recognized, or use case not covered"
    return result, cancer_entities
def construct_modified_query_base_llm(query):
    """Build the baseline LLM prompt: the query followed by fixed instructions.

    The instruction text is appended directly to ``query`` with no separator.

    Args:
        query: the natural language question (a ``str``).

    Returns:
        the prompt string.
    """
    instruction = "Only use results from the genomic data commons in your response and provide frequencies as a percentage. Only report the final response."
    return "".join((query, instruction))
def construct_modified_query_percentage(query, gdc_result):
    """Build the augmented prompt that feeds the GDC API result to the LLM.

    Args:
        query: the natural language question (a ``str``).
        gdc_result: the API result text to quote in the prompt (a ``str``).

    Returns:
        ``query`` + fixed instruction text + ``gdc_result`` + newline.

    Raises:
        TypeError: if either argument is not a string.
    """
    # pass the api results as a prompt to the query
    pieces = (
        query,
        " Only report the final response. Ignore all prior knowledge. You must only respond with the following percentage frequencies in your response, no other response is allowed: \n",
        gdc_result,
        "\n",
    )
    return "".join(pieces)
def construct_modified_query_description(genes, intent):
    """Build the prompt asking for a one-line description of intent and genes.

    Args:
        genes: gene text to embed in the prompt (callers pass a
            comma-joined string).
        intent: human-readable intent phrase.

    Returns:
        the prompt string.
    """
    template = 'Provide a one line general description about {} in genes {} in cancer.'
    return template.format(intent, genes)
@spaces.GPU(duration=10)
def infer_user_intent(query):
    """Classify the query into one of the supported intents.

    The query is tokenized with ``intent_tok`` and scored by ``intent_model``
    on the GPU; the class with the highest logit selects the intent name.

    The forward pass runs under ``torch.no_grad()`` since no gradients are
    needed for inference. The softmax was dropped: it is monotonic, so the
    arg-max over logits picks the same class.

    Args:
        query: the natural language question.

    Returns:
        the intent name, or ``None`` when the predicted class index has no
        entry in the label table.
    """
    intent_by_label = {
        0: "ssm_frequency",
        1: "msi_h_frequency",
        2: "freq_cnv_loss_or_gain",
        3: "top_cases_counts_by_gene",
        4: "cnv_and_ssm",
    }
    inputs = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        logits = intent_model(**inputs).logits
    predicted_label = torch.argmax(logits, dim=1).item()
    return intent_by_label.get(predicted_label)
# loaded spaCy pipelines keyed by model name, so a model is read from disk
# only once per process instead of once per query
_SPACY_PIPELINES = {}


@utilities.timeit
def return_initial_cancer_entities(query, model):
    """Make an initial guess at the cancer entities mentioned in the query.

    Runs the named spaCy pipeline over ``query`` and keeps the text of every
    entity labelled ``"DISEASE"``.

    Args:
        query: the natural language question.
        model: name of the spaCy model to use (passed to ``spacy.load``).

    Returns:
        list of entity texts, in document order.

    Raises:
        whatever ``spacy.load`` raises when the model cannot be loaded.
    """
    nlp = _SPACY_PIPELINES.get(model)
    if nlp is None:
        nlp = spacy.load(model)
        _SPACY_PIPELINES[model] = nlp
    doc = nlp(query)
    return [entity.text for entity in doc.ents if entity.label_ == "DISEASE"]
@utilities.timeit
def construct_and_execute_api_call(query):
    """Combine entity extraction, intent inference and the GDC API call.

    Steps (progress is printed along the way):
      1. find cancer entities - project ids in the query, else a spaCy NER
         guess - and resolve them to project ids; if nothing resolves, every
         known project id is used;
      2. extract gene and mutation entities;
      3. infer the intent;
      4. execute the matching API helper.

    Args:
        query: the natural language question.

    Returns:
        ``SimpleNamespace`` with ``gdc_result``, ``cancer_entities``,
        ``intent``, ``gene_entities`` and ``mutation_entities``. If the API
        step raises, ``gdc_result`` and ``cancer_entities`` are empty lists.
    """
    print(
        "\nStep 1: Starting GDC-QAG on input natural language query:\n{}\n".format(
            query
        )
    )

    # Infer entities
    initial_cancer_entities = check_if_project_id_in_query(query)
    if not initial_cancer_entities:
        try:
            initial_cancer_entities = return_initial_cancer_entities(
                query, model="en_ner_bc5cdr_md"
            )
            print('initial cancer entities {}'.format(initial_cancer_entities))
        except Exception as err:
            print("unable to guess cancer entities {}".format(str(err)))
            initial_cancer_entities = []

    resolved = postprocess_cancer_entities(
        initial_cancer_entities=initial_cancer_entities, query=query
    )
    # if cancer entities is empty from above methods return all projects
    cancer_entities = resolved if resolved else list(project_mappings.keys())

    gene_entities = infer_gene_entities_from_query(query)
    mutation_entities = infer_mutation_entities(
        gene_entities=gene_entities, query=query
    )

    print("\nStep 2: Entity Extraction\n")
    print("gene entities {}".format(gene_entities))
    print("mutation entities {}".format(mutation_entities))
    print("cancer entities {}".format(cancer_entities))

    # infer user intent
    intent = infer_user_intent(query)
    print("\nStep 3: Intent Inference:\n{}\n".format(intent))

    try:
        print("\nStep 4: API call builder for intent {}\n".format(intent))
        api_call_result, cancer_entities = execute_api_call(
            intent, gene_entities, mutation_entities, cancer_entities, query
        )
    except Exception as err:
        print("unable to process query {} {}".format(query, str(err)))
        api_call_result, cancer_entities = [], []

    return SimpleNamespace(
        gdc_result=api_call_result,
        cancer_entities=cancer_entities,
        intent=intent,
        gene_entities=gene_entities,
        mutation_entities=mutation_entities,
    )
# generate llama model response
@utilities.timeit
@spaces.GPU(duration=20)
def generate_percentage_response(modified_query):
    """Generate a regex-constrained percentage answer with the Llama model.

    The prompt is appended to the shared ``base_lm`` and generation is
    constrained to ``"The final response is: <digits>.<digits>%"``.

    Args:
        modified_query: the full prompt text.

    Returns:
        the generated text captured under ``"pct_response"``.
    """
    # set_seed(1042)
    # raw string: "\d" and "\." in a normal literal are invalid escape
    # sequences (a warning today, an error in future Python versions);
    # the pattern value is unchanged
    regex = r"The final response is: \d*\.\d*%"
    lm = base_lm
    lm += modified_query
    lm += guidance_gen("pct_response", n=1, temperature=0, max_tokens=40, regex=regex)
    return lm["pct_response"]
# generate llama model descriptive response
@utilities.timeit
@spaces.GPU(duration=20)
def generate_descriptive_response(modified_query):
    """Generate a single-line descriptive answer with the Llama model.

    The prompt is appended to the shared ``base_lm``; generation is limited
    to 100 tokens and constrained by a regex that forbids newlines.

    Args:
        modified_query: the full prompt text.

    Returns:
        the generated text captured under ``"desc_response"``.
    """
    # raw string: "\S" in a normal literal is an invalid escape sequence;
    # the resulting pattern is character-for-character the same as before
    single_line_regex = r"^[^\n]*[.\S+]$"
    lm = base_lm
    lm += modified_query
    lm += guidance_gen(
        "desc_response",
        n=1,
        temperature=0,
        max_tokens=100,
        regex=single_line_regex,
    )
    return lm["desc_response"]
@utilities.timeit
def batch_test(query):
    """Produce the baseline LLM answer and the GDC-QAG intermediate results.

    Args:
        query: the natural language question.

    Returns:
        ``pd.Series`` of ``[llama_base_output, gdc_result, cancer_entities,
        intent, gene_entities, mutation_entities]``.

    Fixes over the previous version:
      * when ``construct_and_execute_api_call`` raised, the handler assigned
        attributes on ``result``, which was never bound, and failed with a
        ``NameError``; an empty result object is now used instead;
      * the length comparison between helper outputs and cancer entities was
        evaluated and thrown away; a mismatch is now logged. Results are
        still reset only when the lengths cannot be taken at all.
    """
    modified_query = construct_modified_query_base_llm(query)
    print(f"obtain baseline llama-3B response on modified query: {modified_query}")
    llama_base_output = generate_percentage_response(modified_query)
    print(f"llama-3B baseline response: {llama_base_output}")

    try:
        result = construct_and_execute_api_call(query)
    except Exception as e:
        # unable to compute at this time, recheck
        print("unable to process query {} {}".format(query, str(e)))
        result = SimpleNamespace(
            gdc_result=[],
            cancer_entities=[],
            intent=None,
            gene_entities=[],
            mutation_entities=[],
        )

    # if there is not a helper output for each unique cancer entity
    # log error to inspect and reprocess query later
    msg = "there is not a unique helper output for each unique cancer entity in {}".format(
        query
    )
    try:
        lengths_match = len(result.gdc_result) == len(result.cancer_entities)
    except Exception:
        print("exception {}".format(msg))
        result.gdc_result = []
        result.cancer_entities = []
    else:
        if not lengths_match:
            print("exception {}".format(msg))

    return pd.Series(
        [
            llama_base_output,
            result.gdc_result,
            result.cancer_entities,
            result.intent,
            result.gene_entities,
            result.mutation_entities,
        ]
    )
@utilities.timeit
def get_prefinal_response(row):
    """Generate the GDC-augmented percentage response for one result row.

    Args:
        row: mapping-like record with ``"questions"`` and ``"gdc_result"``.

    Returns:
        ``pd.Series`` of ``[percentage_prompt, percentage_response]``; in the
        response the words "final response" are replaced with
        "frequency for your query".

    Raises:
        KeyError: if ``row`` lacks ``"questions"`` or ``"gdc_result"``. The
            old handler printed ``query``/``gdc_result`` - names that are
            unbound exactly when the lookup fails - and then went on to use
            them; the failure is now reported and re-raised explicitly.
    """
    try:
        query = row["questions"]
        gdc_result = row["gdc_result"]
    except KeyError as e:
        print(f"unable to retrieve query or gdc_result from row: {e}")
        raise

    print("\nStep 6: Construct LLM prompts (percentage) for llama-3B\n")
    percentage_prompt = construct_modified_query_percentage(query, gdc_result)

    print("\nStep 7: Generate LLM response R (percentage) on query augmented prompts\n")
    percentage_response = generate_percentage_response(percentage_prompt)
    percentage_response = re.sub(
        r'final response', 'frequency for your query', percentage_response
    )
    return pd.Series([percentage_prompt, percentage_response])
@utilities.timeit
def postprocess_llm_description(descriptive_response):
    """Trim a possibly truncated LLM description to whole sentences.

    Responses shorter than 100 tokens (per ``tok``) are kept as they are.
    Longer ones are split on periods - except a period directly followed by
    digits and ``%``, so a value like ``12.5%`` stays intact - and the last
    fragment is dropped.

    When the long response consists of a single fragment it is kept rather
    than dropped; previously that case collapsed the description to ".".

    Args:
        descriptive_response: text generated by the LLM.

    Returns:
        the cleaned description, always ending with a period. On any error a
        fixed fallback message is returned instead.
    """
    try:
        num_tokens = len(tok.encode(descriptive_response))
        if num_tokens < 100:
            postprocessed_desc_response = descriptive_response
        else:
            fragments = re.split(r'\.(?!\d+%)', descriptive_response)
            # remove empty elements
            fragments = [fragment for fragment in fragments if fragment]
            # drop the trailing fragment, unless it is all there is
            kept = fragments[:-1] if len(fragments) > 1 else fragments
            postprocessed_desc_response = '.'.join(kept)
    except Exception as e:
        print('unable to postprocess LLM gene description {}'.format(
            str(e)
        ))
        postprocessed_desc_response = 'unable to postprocess LLM gene description'

    if not postprocessed_desc_response.endswith('.'):
        postprocessed_desc_response += '.'
    return postprocessed_desc_response
@utilities.timeit
def postprocess_percentage_response(
        gdc_qag_base_stat, gdc_result_percentage, gdc_qag_percentage_response):
    """Make the reported percentage agree with the GDC API percentage.

    If the percentage parsed from the LLM response equals the one parsed from
    the GDC result, the LLM response text is returned unchanged. Otherwise
    the GDC percentage replaces it and a fresh sentence is produced. NaN
    never compares equal, so a NaN on either side takes the replacement path.

    Args:
        gdc_qag_base_stat: percentage parsed from the LLM response.
        gdc_result_percentage: percentage parsed from the GDC result.
        gdc_qag_percentage_response: the LLM response text.

    Returns:
        tuple ``(percentage, response_text)``.
    """
    try:
        # check/confirm if gdc_qag_base_stat percentage == gdc_result_percentage
        # change it, if not
        if gdc_qag_base_stat == gdc_result_percentage:
            final_response = gdc_qag_percentage_response
        else:
            gdc_qag_base_stat = gdc_result_percentage
            final_response = 'The frequency for your query is: {}%'.format(
                gdc_qag_base_stat)
    except Exception as err:
        print('unable to postprocess percentage frequency {}'.format(
            str(err)
        ))
        final_response = 'unable to postprocess percentage frequency'
    return gdc_qag_base_stat, final_response
@utilities.timeit
def postprocess_response(row):
    """Derive the final GDC-QAG answer and comparison statistics for one row.

    Three goals:
      1. check/confirm the percentage in the GDC-QAG percentage response
         against the GDC API result and return a percentage response;
      2. postprocess the descriptive response;
      3. return the concatenated final response
         (descriptive response + percentage response).

    Args:
        row: mapping-like record. ``"gdc_result"`` and
            ``"percentage_response"`` are optional (NaN is used when a lookup
            fails); ``"llama_base_output"`` and ``"descriptive_response"``
            are read without a guard, so a failing lookup propagates.

    Returns:
        ``pd.Series`` of ``[llama_base_stat, gdc_qag_base_stat,
        final_gdc_qag_desc_response, final_gdc_qag_percentage_response,
        final_gdc_qag_response]``. A percentage that cannot be extracted is
        NaN.
    """
    percent_re = re.compile(r".*?(\d*\.\d*)%.*?")

    def first_percentage(text):
        # float value of the first "<digits>.<digits>%" in text; raises when
        # text is not a string, nothing matches, or the match is not a number
        return float(percent_re.search(text).group(1))

    ###### various inputs ###############################
    # this is the result obtained in GDC-QAG via API
    try:
        gdc_result = row["gdc_result"]
    except Exception as err:
        print('GDC Result not found in gdc_qag output, returning nan {}'.format(
            str(err)
        ))
        gdc_result = np.nan

    # extract gdc_result percentage from gdc_result
    try:
        gdc_match = percent_re.search(gdc_result)
        if gdc_match is None:
            gdc_result_percentage = np.nan
            print('no data available in gdc')
        else:
            gdc_result_percentage = float(gdc_match.group(1))
    except Exception as err:
        print('unable to extract percentage from gdc result {}'.format(
            str(err)))
        gdc_result_percentage = np.nan

    # this is the LLM generated response with freq, after seeing gdc_result
    try:
        gdc_qag_percentage_response = row['percentage_response']
    except Exception as err:
        print('LLM generated gdc_qag percentage response not found, returning nan {}'.format(
            str(err)
        ))
        gdc_qag_percentage_response = np.nan

    # extract gdc_qag percentage from LLM response
    try:
        gdc_qag_base_stat = first_percentage(gdc_qag_percentage_response)
    except Exception as err:
        print('unable to extract percentage from gdc_qag percentage response {}'.format(
            str(err)))
        gdc_qag_base_stat = np.nan

    # llama-3B base output
    llama_base_output = row["llama_base_output"]
    # extract llama percentage from llama base output
    try:
        llama_base_stat = first_percentage(llama_base_output)
    except Exception as err:
        print('unable to extract llama base stat {}'.format(str(err)))
        llama_base_stat = np.nan

    ############ postprocess LLM description + percentage ###############
    desc_response = postprocess_llm_description(row['descriptive_response'])
    gdc_qag_base_stat, pct_response = postprocess_percentage_response(
        gdc_qag_base_stat, gdc_result_percentage, gdc_qag_percentage_response
    )
    combined_response = desc_response + ' ' + pct_response

    return pd.Series(
        [
            llama_base_stat,
            gdc_qag_base_stat,
            desc_response,
            pct_response,
            combined_response,
        ]
    )
def format_error_string():
    """Return the Markdown block-quote shown when the pipeline fails.

    The text starts and ends with a newline and contains a heading line and
    a hint line, each prefixed with ``"> "``.
    """
    hint = "Error Executing the query. Please checkout 'Examples' to formulate your search query. To specify cancer types, refer to the Project Name from the Genomic Data Commons, e.g. 'breast invasive carcinoma' for breast cancer."
    lines = [
        "",
        "> Query augmented generation error:",
        "> " + hint,
        "",
    ]
    return "\n".join(lines)
def wrap_output(result_str):
    """Wrap ``result_str`` to lines of at most 80 characters.

    Args:
        result_str: text to wrap.

    Returns:
        the wrapped text with lines joined by newlines; an empty string for
        empty input.
    """
    # textwrap.fill is the documented shorthand for "\n".join(textwrap.wrap(...))
    return textwrap.fill(result_str, width=80)
def format_result_string(result):
    """Render a single-result dict as three fenced Markdown blocks.

    Args:
        result: dict whose ``'GDC-QAG results'`` entry maps the column names
            ``'Question'``, ``'Gene entities'``, ``'Mutation entities'``,
            ``'Cancer entities'``, ``'Intent'`` and
            ``'Query augmented generation'`` to their values.

    Returns:
        the Markdown text (question, intermediate outputs, final response).
        The text is also printed.
    """
    record = result['GDC-QAG results']
    fence = "```"
    lines = [
        "",
        fence,
        "Question:",
        "{}".format(record['Question']),
        fence,
        fence,
        "QAG intermediate outputs:",
        "Gene entities: {}".format(record['Gene entities']),
        "Mutation entities: {}".format(record['Mutation entities']),
        "Cancer entities: {}".format(record['Cancer entities']),
        "Intent: {}".format(record['Intent']),
        fence,
        fence,
        "QAG final response:",
        "{}".format(record['Query augmented generation']),
        fence,
        "",
    ]
    result_string = "\n".join(lines)
    print('result_string {}'.format(result_string))
    return result_string
def format_result_string_multi(result):
    """Render a multi-row result as one fenced Markdown block.

    Args:
        result: table with a ``'response_with_cancer'`` column; its values
            are converted to strings and listed one per line.

    Returns:
        the Markdown text. Both the joined rows and the final text are
        printed.
    """
    per_row_text = result['response_with_cancer'].astype(str).tolist()
    multi_result = "\n".join(per_row_text)
    print('multi result {}'.format(multi_result))
    # test final response only
    # test adding entities soonafter
    lines = ["", "```", "QAG final response:", multi_result, "```", ""]
    result_string = "\n".join(lines)
    print('result_string {}'.format(result_string))
    return result_string
@utilities.timeit
def execute_pipeline(question: str):
    """Run the full GDC-QAG pipeline for one question and format the answer.

    Stages: baseline LLM answer and GDC API result (``batch_test``), one
    descriptive LLM response from genes and intent, one percentage response
    per GDC result row, final consistency check, then Markdown formatting.

    Args:
        question: the natural language question entered by the user.

    Returns:
        Markdown text with the result. If any stage raises, the generic
        error text from ``format_error_string`` is returned instead; the
        exception is now printed first (it used to be swallowed silently,
        which left failures undiagnosable from the logs).
    """
    df = pd.DataFrame({"questions": [question]})
    print(f"\n\nQuestion received: {question}\n")
    try:
        # queries input file
        df[
            [
                "llama_base_output",
                "gdc_result",
                "cancer_entities",
                "intent",
                "gene_entities",
                "mutation_entities",
            ]
        ] = df["questions"].apply(batch_test)
        # one row per GDC result entry
        df_exploded = df.explode("gdc_result", ignore_index=True)

        # generate descriptive response once based on genes and intent
        print("\nStep 6: Construct LLM prompts (descriptive) for llama-3B\n")
        intent = intent_expansion[df['intent'].iloc[0]]
        genes = ','.join(df['gene_entities'].iloc[0])
        print('intent, genes {} {}'.format(intent, genes))
        descriptive_prompt = construct_modified_query_description(genes, intent)
        print('desc prompt {}'.format(descriptive_prompt))

        print("\nStep 7: Generate LLM response R (descriptive) on query augmented prompts\n")
        descriptive_response = generate_descriptive_response(descriptive_prompt)
        print('desc response {}'.format(descriptive_response))
        if not descriptive_response.endswith('.'):
            descriptive_response += '.'

        # the same prompt/response is assigned to every exploded row
        df_exploded[['descriptive_prompt', 'descriptive_response']] = descriptive_prompt, descriptive_response
        df_exploded[["percentage_prompt", "percentage_response"]] = df_exploded.apply(
            get_prefinal_response, axis=1)

        ### postprocess response
        print("\nStep 8: Final check and confirmation\n")
        df_exploded[
            [
                "llama_base_stat",
                "gdc_qag_base_stat",
                "final_gdc_qag_desc_response",
                "final_gdc_qag_percentage_response",
                "final_gdc_qag_response"
            ]
        ] = df_exploded.apply(postprocess_response, axis=1)

        final_columns = utilities.get_final_columns()
        result = df_exploded[final_columns].copy()
        result.rename(
            columns={
                "llama_base_output": "llama-3B baseline output",
                "descriptive_prompt": "Descriptive prompt",
                "percentage_prompt": "Percentage prompt",
                "gdc_result": "GDC Result",
                "gdc_qag_base_stat": "GDC-QAG frequency",
                "llama_base_stat": "llama-3B baseline frequency",
                "final_gdc_qag_response": "Query augmented generation",
                "intent": "Intent",
                "cancer_entities": "Cancer entities",
                "gene_entities": "Gene entities",
                "mutation_entities": "Mutation entities",
                "questions": "Question",
            },
            inplace=True,
        )
        result.index = ["GDC-QAG results"] * len(result)
        print("completed")

        print("\nWriting result string now\n")
        if result.shape[0] > 1:
            # several result rows: show each response next to its GDC result
            result['response_with_cancer'] = result['Query augmented generation'] + '.' + result['GDC Result']
            print('multi cancer result {}'.format(result))
            result_string = format_result_string_multi(result)
        else:
            result = result.T.to_dict()
            result_string = format_result_string(result)
    except Exception as e:
        # top-level boundary for the UI: report the failure in the logs and
        # show the generic error message to the user
        print('unable to execute pipeline for question {} {}'.format(question, str(e)))
        result_string = format_error_string()
    return result_string
def visible_component(input_text):
    """Return a Gradio update that sets a component's value to "WHATEVER".

    Args:
        input_text: accepted but not used.
    """
    component_update = gr.update(value="WHATEVER")
    return component_update
# Create Gradio interface
# Layout: title, query textbox, example queries, an Execute button and a
# Markdown output area. The custom CSS enlarges the label, placeholder and
# text of the element with id "format-textbox" (the query textbox).
with gr.Blocks(title="GDC QAG MCP server", css="""
#format-textbox label {
font-size: 25px;
font-weight: bold;
}
#format-textbox input::placeholder {
font-size: 20px;
}
#format-textbox .svelte-1ipelgc {
font-size: 18px;
}
""") as GDC_QAG_QUERY:
    gr.Markdown(
        """
# GDC-QAG Service
"""
    )
    # NOTE(review): the original indentation was lost; the textbox is assumed
    # to be the only component inside the row - confirm against the deployed UI.
    with gr.Row():
        query_input = gr.Textbox(
            lines=3,
            label="Please see 'Examples' below to test sample queries. Formulate your search query similar to examples. To specify cancer types, refer to the Project Name from the Genomic Data Commons, e.g. 'breast invasive carcinoma' for breast cancer.",
            placeholder='e.g. "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
            info="Required: Enter your query. Please retry query if GDC API is unavailable or connection aborts.",
            elem_id="format-textbox"
        )
    # clickable sample questions that fill the textbox
    gr.Examples(
        examples=EXAMPLE_INPUTS, inputs=query_input, example_labels=EXAMPLE_LABELS
    )
    execute_button = gr.Button("Execute", variant="primary")
    output = gr.Markdown("""
### Query Result
_The result of the query will appear here_
"""
    )
    # run the pipeline on click and render its Markdown result
    execute_button.click(
        fn=execute_pipeline,
        inputs=[query_input],
        outputs=output,
    )
# launch the app (with MCP server and API page enabled) when run as a script
if __name__ == "__main__":
    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)
|