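# GDC-QAG: a Gradio app (and MCP server) for question answering over the NCI
# Genomic Data Commons. The pipeline extracts gene, mutation and cancer (project)
# entities from a natural language query, classifies the user intent with a
# fine-tuned BERT model, executes the corresponding GDC API call, and uses a
# guidance-constrained Llama-3.2-3B-Instruct model to phrase the descriptive and
# percentage responses.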
import json
import os
import re
import textwrap
from itertools import chain
from types import SimpleNamespace

import gradio as gr
import numpy as np
import pandas as pd
import spaces
import spacy
import torch
from guidance import gen as guidance_gen
from guidance.models import Transformers
from sentence_transformers import SentenceTransformer, util
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BertForSequenceClassification,
    BertTokenizer,
    set_seed,
)

from methods import gdc_api_calls, utilities

# set up various tokens
hf_TOKEN = os.environ.get("hf_svc_ctds", False)

# disable tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"
EXAMPLE_INPUTS = [
    "What is the co-occurrence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?",
    "What is the co-occurrence frequency of somatic heterozygous deletions in BRCA2 and NF1 in the Kidney Chromophobe TCGA-KICH project in the genomic data commons?",
    "What percentage of ovarian serous cystadenocarcinoma cases have a somatic heterozygous deletion in BRCA1 and simple somatic mutations in BRCA1 in the genomic data commons?",
    "What fraction of cases have simple somatic mutations or copy number variants in ALK in Lung Adenocarcinoma TCGA-LUAD project in the genomic data commons?",
    "How often is microsatellite instability observed in Colon Adenocarcinoma TCGA-COAD project in the genomic data commons?",
    "How often is the BRAF V600E mutation found in Skin Cutaneous Melanoma TCGA-SKCM project in the genomic data commons?",
    "What is the co-occurrence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?",
    "In Lung Adenocarcinoma TCGA-LUAD project data from the genomic data commons, what is the frequency of ALK amplification?",
]

EXAMPLE_LABELS = [
    "combination homozygous deletions",
    "combination heterozygous deletions",
    "heterozygous deletion and somatic mutations",
    "copy number variants or somatic mutations",
    "microsatellite-instability",
    "simple somatic mutation",
    "combination somatic mutations",
    "single gene amplification",
]
# for natural language gene and intent descriptions
intent_expansion = {
    "cnv_and_ssm": "copy number variants or simple somatic mutations",
    "freq_cnv_loss_or_gain": "copy number variant losses or gains",
    "msi_h_frequency": "microsatellite instability",
    "freq_cnv_loss_or_gain_comb": "copy number variant losses or gains",
    "ssm_frequency": "simple somatic mutations",
    "top_cases_counts_by_gene": "copy number variants or simple somatic mutations",
}
# set up requirements: models and data
print("getting gdc project information")
project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)

print("loading intent model and tokenizer")
model_id = "uc-ctds/query_intent"
intent_tok = AutoTokenizer.from_pretrained(
    model_id, trust_remote_code=True, token=hf_TOKEN
)
intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
intent_model = intent_model.to("cuda").eval()

# load sentence transformer model to test cancer embeddings
print("loading sentence transformer model")
st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
st_model = st_model.to("cuda").eval()

print("loading gdc genes and mutations")
gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)

print("loading llama-3B model and tokenizer")
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, trust_remote_code=True, token=hf_TOKEN
)
model = model.to("cuda").eval()

# global init to test guidance speed up
base_lm = Transformers(model=model, tokenizer=tok)
def infer_mutation_entities(gene_entities, query):
    # collect known GDC mutations of the detected genes that are mentioned in the query
    mutation_entities = []
    for g in gene_entities:
        for m in gdc_genes_mutations[g]:
            if m in query:
                mutation_entities.append(m)
    return mutation_entities


def infer_gene_entities_from_query(query):
    entities = []
    # gene recognition with a simple dictionary-based method
    for g in gdc_genes_mutations.keys():
        if g in query.split(" "):
            entities.append(g)
    return entities
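# Illustrative usage of the two lookups above (assumes "BRAF" is a key in
# gdc_genes_mutations and "V600E" is one of its listed mutations):
#   infer_gene_entities_from_query("How often is the BRAF V600E mutation found?")
#   -> ["BRAF"]
#   infer_mutation_entities(["BRAF"], "How often is the BRAF V600E mutation found?")
#   -> ["V600E"]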
def get_project_embeddings():
    project_rows = []
    for k, v in project_mappings.items():
        new_v = [item.replace(",", "") for item in v]
        combined = ",".join([k] + new_v)
        project_rows.append(combined)
    row_embeddings = st_model.encode(project_rows, convert_to_tensor=True, device="cuda")
    return project_rows, row_embeddings.cpu().numpy()


def check_if_project_id_in_query(query):
    # check whether a project key, e.g. TCGA-BRCA, is mentioned in the query
    project_list = project_mappings.keys()
    cancer_entities = [
        potential_ce
        for potential_ce in query.split(" ")
        if potential_ce in project_list
    ]
    return cancer_entities
def proj_id_and_partial_match(query, initial_cancer_entities):
    final_entities = []
    if initial_cancer_entities:
        # check for a match between initial cancer entities and GDC project descriptions,
        # e.g. match "ovarian serous cystadenocarcinoma" to the TCGA-OV project
        for ic in initial_cancer_entities:
            for k, v in project_mappings.items():
                for c in v:
                    if ic in c.lower():
                        final_entities.append(k)
    else:
        # no initial cancer entities; check for a match between query terms
        # and GDC project descriptions
        for term in query.lower().split(" "):
            for k, v in project_mappings.items():
                for c in v:
                    if term in c.lower():
                        final_entities.append(k)
    return list(set(final_entities))
def get_top_k_scores(query, row_embeddings, top_k=20):
    query_embedding = st_model.encode(query, convert_to_tensor=True, device="cuda")
    row_embeddings = torch.from_numpy(row_embeddings).float().to("cuda")
    cosine_scores = util.cos_sim(query_embedding, row_embeddings)
    top_results = torch.topk(cosine_scores, k=top_k)
    # move to CPU and return
    top_results_scores = top_results.values.cpu().tolist()
    top_results_indices = top_results.indices.cpu().tolist()
    return top_results_scores, top_results_indices


def get_top_k_cancer_entities(project_rows, top_results_scores, top_results_indices):
    top_cancer_entities = []
    for idx, score in enumerate(top_results_scores[0]):
        if score > 0.5:
            row_idx = top_results_indices[0][idx]
            print("best row, score: {} {}".format(project_rows[row_idx], score))
            top_cancer_entities.append([project_rows[row_idx], score])
    try:
        top_projects = [sublist[0].split(",")[0] for sublist in top_cancer_entities]
    except Exception:
        top_projects = []
    return top_projects
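# Illustrative usage of the embedding-based match (project row and score are hypothetical):
#   project_rows, row_embeddings = get_project_embeddings()
#   scores, indices = get_top_k_scores("kidney chromophobe", row_embeddings)
#   get_top_k_cancer_entities(project_rows, scores, indices)  # e.g. ["TCGA-KICH"]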
def postprocess_cancer_entities(initial_cancer_entities, query):
    # get project embeddings
    print("loading cancer embeddings")
    project_rows, row_embeddings = get_project_embeddings()
    final_entities = check_if_project_id_in_query(query)
    if final_entities:
        return final_entities
    else:
        if initial_cancer_entities:
            # test 1 (w/ initial entities): query the GDC projects endpoint for a project_id
            gdc_project_match = gdc_api_calls.map_cancer_entities_to_project(
                initial_cancer_entities, project_mappings
            )
            if gdc_project_match.values():
                final_entities = list(gdc_project_match.values())
            if not final_entities:
                # test 2 (w/ initial entities): no result from the GDC projects endpoint,
                # check for matches between query terms and gdc project_mappings
                final_entities = proj_id_and_partial_match(
                    query, initial_cancer_entities
                )
            # try embedding based match
            if not final_entities:
                print("Test embedding based match")
                for i in initial_cancer_entities:
                    top_results_scores, top_results_indices = get_top_k_scores(
                        i, row_embeddings
                    )
                    c_entities = get_top_k_cancer_entities(
                        project_rows, top_results_scores, top_results_indices
                    )
                    final_entities.append(c_entities)
                final_entities = list(chain.from_iterable(final_entities))
        else:
            # test 3 (w/o initial entities): check for matches between query terms
            # and gdc project_mappings
            final_entities = proj_id_and_partial_match(
                query, initial_cancer_entities
            )
    return final_entities
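# Cancer-entity resolution above falls back in this order:
#   1. explicit project IDs mentioned in the query (e.g. TCGA-MESO),
#   2. NER entities mapped through the GDC projects endpoint,
#   3. partial string matches against the GDC project descriptions,
#   4. embedding similarity between each NER entity and project rows (score > 0.5);
# construct_and_execute_api_call defaults to all projects if every step comes up empty.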
def execute_api_call(intent, gene_entities, mutation_entities, cancer_entities, query):
    if intent == "ssm_frequency":
        result, cancer_entities = utilities.get_ssm_frequency(
            gene_entities, mutation_entities, cancer_entities, project_mappings
        )
    elif intent == "top_mutated_genes_by_project":
        result = gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=10
        )
    elif intent == "most_frequently_mutated_gene":
        result = gdc_api_calls.get_top_mutated_genes_by_project(
            cancer_entities, top_k=1
        )
    elif intent == "freq_cnv_loss_or_gain":
        result, cancer_entities = gdc_api_calls.get_freq_cnv_loss_or_gain(
            gene_entities, cancer_entities, query, cnv_and_ssm_flag=False
        )
    elif intent == "msi_h_frequency":
        result, cancer_entities = gdc_api_calls.get_msi_frequency(cancer_entities)
    elif intent == "cnv_and_ssm":
        result, cancer_entities = utilities.get_freq_of_cnv_and_ssms(
            query, cancer_entities, gene_entities, gdc_genes_mutations
        )
    elif intent == "top_cases_counts_by_gene":
        result, cancer_entities = gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    elif intent == "project_summary":
        result = gdc_api_calls.get_project_summary(cancer_entities)
    else:
        result = "user intent not recognized, or use case not covered"
    return result, cancer_entities
def construct_modified_query_base_llm(query):
    # note the leading space so the instruction does not run into the query text
    prompt_template = (
        " Only use results from the genomic data commons in your response and "
        "provide frequencies as a percentage. Only report the final response."
    )
    modified_query = query + prompt_template
    return modified_query


def construct_modified_query_percentage(query, gdc_result):
    # pass the api results as a prompt to the query
    prompt_template = (
        " Only report the final response. Ignore all prior knowledge. You must only "
        "respond with the following percentage frequencies in your response, "
        "no other response is allowed: \n"
        + gdc_result
        + "\n"
    )
    modified_query = query + prompt_template
    return modified_query


def construct_modified_query_description(genes, intent):
    modified_query = f"Provide a one line general description about {intent} in genes {genes} in cancer."
    return modified_query
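# For example, with genes "ALK" and the expanded intent
# "copy number variants or simple somatic mutations", the descriptive prompt is:
#   "Provide a one line general description about copy number variants or
#    simple somatic mutations in genes ALK in cancer."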
def infer_user_intent(query):
    intent_labels = {
        "ssm_frequency": 0.0,
        "msi_h_frequency": 1.0,
        "freq_cnv_loss_or_gain": 2.0,
        "top_cases_counts_by_gene": 3.0,
        "cnv_and_ssm": 4.0,
    }
    inputs = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    outputs = intent_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    predicted_label = torch.argmax(probs, dim=1).item()
    for k, v in intent_labels.items():
        if v == predicted_label:
            return k


# initial guesses for cancer entities
def return_initial_cancer_entities(query, model):
    nlp = spacy.load(model)
    doc = nlp(query)
    result = doc.ents
    initial_cancer_entities = [e.text for e in result if e.label_ == "DISEASE"]
    return initial_cancer_entities
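# Illustrative intent prediction (the label depends on the fine-tuned classifier):
#   infer_user_intent("How often is microsatellite instability observed in TCGA-COAD?")
#   -> "msi_h_frequency"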
# function to combine entities, intent and API call
def construct_and_execute_api_call(query):
    print(
        "\nStep 1: Starting GDC-QAG on input natural language query:\n{}\n".format(
            query
        )
    )
    # infer entities
    initial_cancer_entities = check_if_project_id_in_query(query)
    if not initial_cancer_entities:
        try:
            initial_cancer_entities = return_initial_cancer_entities(
                query, model="en_ner_bc5cdr_md"
            )
            print("initial cancer entities {}".format(initial_cancer_entities))
        except Exception as e:
            print("unable to guess cancer entities {}".format(str(e)))
            initial_cancer_entities = []
    cancer_entities = postprocess_cancer_entities(
        initial_cancer_entities=initial_cancer_entities, query=query
    )
    # if cancer entities is empty from the above methods, return all projects
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())
    gene_entities = infer_gene_entities_from_query(query)
    mutation_entities = infer_mutation_entities(
        gene_entities=gene_entities, query=query
    )
    print("\nStep 2: Entity Extraction\n")
    print("gene entities {}".format(gene_entities))
    print("mutation entities {}".format(mutation_entities))
    print("cancer entities {}".format(cancer_entities))
    # infer user intent
    intent = infer_user_intent(query)
    print("\nStep 3: Intent Inference:\n{}\n".format(intent))
    try:
        print("\nStep 4: API call builder for intent {}\n".format(intent))
        api_call_result, cancer_entities = execute_api_call(
            intent, gene_entities, mutation_entities, cancer_entities, query
        )
    except Exception as e:
        print("unable to process query {} {}".format(query, str(e)))
        api_call_result = []
        cancer_entities = []
    return SimpleNamespace(
        gdc_result=api_call_result,
        cancer_entities=cancer_entities,
        intent=intent,
        gene_entities=gene_entities,
        mutation_entities=mutation_entities,
    )
# generate llama model response (percentage), constrained by a regex via guidance
def generate_percentage_response(modified_query):
    # set_seed(1042)
    regex = r"The final response is: \d*\.\d*%"
    lm = base_lm
    lm += modified_query
    lm += guidance_gen("pct_response", n=1, temperature=0, max_tokens=40, regex=regex)
    return lm["pct_response"]


# generate llama model descriptive response
def generate_descriptive_response(modified_query):
    lm = base_lm
    lm += modified_query
    lm += guidance_gen(
        "desc_response",
        n=1,
        temperature=0,
        max_tokens=100,
        regex=r"^[^\n]*[.\S+]$",
    )
    return lm["desc_response"]
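# The regex constraints above force guidance to emit responses of a fixed shape,
# e.g. "The final response is: 12.5%" for the percentage path (value illustrative)
# and a single-line sentence for the descriptive path.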
def batch_test(query):
    modified_query = construct_modified_query_base_llm(query)
    print(f"obtain baseline llama-3B response on modified query: {modified_query}")
    llama_base_output = generate_percentage_response(modified_query)
    print(f"llama-3B baseline response: {llama_base_output}")
    try:
        result = construct_and_execute_api_call(query)
    except Exception as e:
        # unable to compute at this time; return an empty result so the query can be rechecked
        print("unable to process query {} {}".format(query, str(e)))
        result = SimpleNamespace(
            gdc_result=[],
            cancer_entities=[],
            intent=None,
            gene_entities=[],
            mutation_entities=[],
        )
    # if there is not a helper output for each unique cancer entity,
    # log the error to inspect and reprocess the query later
    try:
        assert len(result.gdc_result) == len(result.cancer_entities)
    except Exception:
        msg = (
            "there is not a unique helper output for each unique "
            "cancer entity in {}".format(query)
        )
        print("exception {}".format(msg))
        result.gdc_result = []
        result.cancer_entities = []
    return pd.Series(
        [
            llama_base_output,
            result.gdc_result,
            result.cancer_entities,
            result.intent,
            result.gene_entities,
            result.mutation_entities,
        ]
    )
def get_prefinal_response(row):
    try:
        query = row["questions"]
        gdc_result = row["gdc_result"]
    except Exception as e:
        print("unable to retrieve query or gdc_result from row: {}".format(str(e)))
    print("\nStep 6: Construct LLM prompts (percentage) for llama-3B\n")
    percentage_prompt = construct_modified_query_percentage(query, gdc_result)
    print("\nStep 7: Generate LLM response R (percentage) on query augmented prompts\n")
    percentage_response = generate_percentage_response(percentage_prompt)
    percentage_response = re.sub(
        r"final response", "frequency for your query", percentage_response
    )
    return pd.Series([percentage_prompt, percentage_response])
def postprocess_llm_description(descriptive_response):
    try:
        num_tokens = len(tok.encode(descriptive_response))
        if num_tokens < 100:
            postprocessed_desc_response = descriptive_response
        else:
            # the response likely hit max_tokens; drop the truncated final sentence,
            # splitting on periods that are not part of a percentage such as "12.5%"
            response_list = re.split(r"\.(?!\d+%)", descriptive_response)
            # remove empty elements
            filtered_list = list(filter(None, response_list))
            postprocessed_desc_response = ".".join(filtered_list[:-1])
    except Exception as e:
        print("unable to postprocess LLM gene description {}".format(str(e)))
        postprocessed_desc_response = "unable to postprocess LLM gene description"
    if not postprocessed_desc_response.endswith("."):
        postprocessed_desc_response += "."
    return postprocessed_desc_response
def postprocess_percentage_response(
    gdc_qag_base_stat, gdc_result_percentage, gdc_qag_percentage_response
):
    try:
        # check/confirm that the gdc_qag_base_stat percentage equals gdc_result_percentage;
        # if not, fall back to the percentage reported by the GDC API
        if gdc_qag_base_stat != gdc_result_percentage:
            gdc_qag_base_stat = gdc_result_percentage
            final_gdc_qag_percentage_response = "The frequency for your query is: {}%".format(
                gdc_qag_base_stat
            )
        else:
            final_gdc_qag_percentage_response = gdc_qag_percentage_response
    except Exception as e:
        print("unable to postprocess percentage frequency {}".format(str(e)))
        final_gdc_qag_percentage_response = "unable to postprocess percentage frequency"
    return gdc_qag_base_stat, final_gdc_qag_percentage_response
def postprocess_response(row):
    # three goals:
    # goal 1: check/confirm the results in the gdc-qag percentage response and
    #         return a percentage response for gdc-qag
    # goal 2: postprocess the descriptive response
    # goal 3: return the concatenated final response from gdc-qag
    #         (descriptive response + percentage response)
    pattern = r".*?(\d*\.\d*)%.*?"
    ###### various inputs ###############################
    try:
        # this is the result obtained in GDC-QAG via the API
        gdc_result = row["gdc_result"]
    except Exception as e:
        print("GDC Result not found in gdc_qag output, returning nan {}".format(str(e)))
        gdc_result = np.nan
    try:
        # extract the percentage from gdc_result
        match = re.search(pattern, gdc_result)
        if match:
            gdc_result_percentage = float(match.group(1))
        else:
            gdc_result_percentage = np.nan
            print("no data available in gdc")
    except Exception as e:
        print("unable to extract percentage from gdc result {}".format(str(e)))
        gdc_result_percentage = np.nan
    try:
        # this is the LLM generated response with the frequency, after seeing gdc_result
        gdc_qag_percentage_response = row["percentage_response"]
    except Exception as e:
        print("LLM generated gdc_qag percentage response not found, returning nan {}".format(str(e)))
        gdc_qag_percentage_response = np.nan
    try:
        # extract the gdc_qag percentage from the LLM response
        gdc_qag_base_stat = float(re.search(pattern, gdc_qag_percentage_response).group(1))
    except Exception as e:
        print("unable to extract percentage from gdc_qag percentage response {}".format(str(e)))
        gdc_qag_base_stat = np.nan
    # llama-3B base output
    llama_base_output = row["llama_base_output"]
    try:
        # extract the llama percentage from the llama base output
        llama_base_stat = float(re.search(pattern, llama_base_output).group(1))
    except Exception as e:
        print("unable to extract llama base stat {}".format(str(e)))
        llama_base_stat = np.nan
    ############ postprocess LLM description + percentage ###############
    final_gdc_qag_desc_response = postprocess_llm_description(row["descriptive_response"])
    gdc_qag_base_stat, final_gdc_qag_percentage_response = postprocess_percentage_response(
        gdc_qag_base_stat, gdc_result_percentage, gdc_qag_percentage_response
    )
    final_gdc_qag_response = (
        final_gdc_qag_desc_response + " " + final_gdc_qag_percentage_response
    )
    return pd.Series(
        [
            llama_base_stat,
            gdc_qag_base_stat,
            final_gdc_qag_desc_response,
            final_gdc_qag_percentage_response,
            final_gdc_qag_response,
        ]
    )
def format_error_string():
    error_string = (
        "Error executing the query. Please check out 'Examples' to formulate your "
        "search query. To specify cancer types, refer to the Project Name from the "
        "Genomic Data Commons, e.g. 'breast invasive carcinoma' for breast cancer."
    )
    error_string = f"""
> Query augmented generation error:
> {error_string}
"""
    return error_string


def wrap_output(result_str):
    return "\n".join(textwrap.wrap(result_str, width=80))
def format_result_string(result):
    result_string = f"""
```
Question:
{result['GDC-QAG results']['Question']}
```
```
QAG intermediate outputs:
Gene entities: {result['GDC-QAG results']['Gene entities']}
Mutation entities: {result['GDC-QAG results']['Mutation entities']}
Cancer entities: {result['GDC-QAG results']['Cancer entities']}
Intent: {result['GDC-QAG results']['Intent']}
```
```
QAG final response:
{result['GDC-QAG results']['Query augmented generation']}
```
"""
    print("result_string {}".format(result_string))
    return result_string


def format_result_string_multi(result):
    multi_result = "\n".join(result["response_with_cancer"].astype(str))
    print("multi result {}".format(multi_result))
    # test final response only
    # test adding entities soon after
    result_string = f"""
```
QAG final response:
{multi_result}
```
"""
    print("result_string {}".format(result_string))
    return result_string
def execute_pipeline(question: str):
    df = pd.DataFrame({"questions": [question]})
    print(f"\n\nQuestion received: {question}\n")
    try:
        # queries input file
        df[
            [
                "llama_base_output",
                "gdc_result",
                "cancer_entities",
                "intent",
                "gene_entities",
                "mutation_entities",
            ]
        ] = df["questions"].apply(lambda x: batch_test(x))
        df_exploded = df.explode("gdc_result", ignore_index=True)
        # generate the descriptive response once, based on genes and intent
        print("\nStep 6: Construct LLM prompts (descriptive) for llama-3B\n")
        intent = intent_expansion[df["intent"].iloc[0]]
        genes = ",".join(df["gene_entities"].iloc[0])
        print("intent, genes {} {}".format(intent, genes))
        descriptive_prompt = construct_modified_query_description(genes, intent)
        print("desc prompt {}".format(descriptive_prompt))
        print("\nStep 7: Generate LLM response R (descriptive) on query augmented prompts\n")
        descriptive_response = generate_descriptive_response(descriptive_prompt)
        print("desc response {}".format(descriptive_response))
        if not descriptive_response.endswith("."):
            descriptive_response += "."
        df_exploded[["descriptive_prompt", "descriptive_response"]] = (
            descriptive_prompt,
            descriptive_response,
        )
        df_exploded[["percentage_prompt", "percentage_response"]] = df_exploded.apply(
            lambda x: get_prefinal_response(x), axis=1
        )
        ### postprocess response
        print("\nStep 8: Final check and confirmation\n")
        df_exploded[
            [
                "llama_base_stat",
                "gdc_qag_base_stat",
                "final_gdc_qag_desc_response",
                "final_gdc_qag_percentage_response",
                "final_gdc_qag_response",
            ]
        ] = df_exploded.apply(lambda x: postprocess_response(x), axis=1)
        final_columns = utilities.get_final_columns()
        result = df_exploded[final_columns].copy()
        result.rename(
            columns={
                "llama_base_output": "llama-3B baseline output",
                "descriptive_prompt": "Descriptive prompt",
                "percentage_prompt": "Percentage prompt",
                "gdc_result": "GDC Result",
                "gdc_qag_base_stat": "GDC-QAG frequency",
                "llama_base_stat": "llama-3B baseline frequency",
                "final_gdc_qag_response": "Query augmented generation",
                "intent": "Intent",
                "cancer_entities": "Cancer entities",
                "gene_entities": "Gene entities",
                "mutation_entities": "Mutation entities",
                "questions": "Question",
            },
            inplace=True,
        )
        result.index = ["GDC-QAG results"] * len(result)
        print("completed")
        print("\nWriting result string now\n")
        if result.shape[0] > 1:
            result["response_with_cancer"] = (
                result["Query augmented generation"] + "." + result["GDC Result"]
            )
            print("multi cancer result {}".format(result))
            result_string = format_result_string_multi(result)
        else:
            result = result.T.to_dict()
            result_string = format_result_string(result)
    except Exception as e:
        print("unable to execute pipeline for question {} {}".format(question, str(e)))
        result_string = format_error_string()
    return result_string
def visible_component(input_text):
    return gr.update(value="WHATEVER")


# Create Gradio interface
with gr.Blocks(
    title="GDC QAG MCP server",
    css="""
    #format-textbox label {
        font-size: 25px;
        font-weight: bold;
    }
    #format-textbox input::placeholder {
        font-size: 20px;
    }
    #format-textbox .svelte-1ipelgc {
        font-size: 18px;
    }
    """,
) as GDC_QAG_QUERY:
    gr.Markdown(
        """
# GDC-QAG Service
"""
    )
    with gr.Row():
        query_input = gr.Textbox(
            lines=3,
            label="Please see 'Examples' below to test sample queries. Formulate your search query similarly to the examples. To specify cancer types, refer to the Project Name from the Genomic Data Commons, e.g. 'breast invasive carcinoma' for breast cancer.",
            placeholder='e.g. "What is the co-occurrence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
            info="Required: Enter your query. Please retry the query if the GDC API is unavailable or the connection aborts.",
            elem_id="format-textbox",
        )
    gr.Examples(
        examples=EXAMPLE_INPUTS, inputs=query_input, example_labels=EXAMPLE_LABELS
    )
    execute_button = gr.Button("Execute", variant="primary")
    output = gr.Markdown(
        """
### Query Result
_The result of the query will appear here_
"""
    )
    execute_button.click(
        fn=execute_pipeline,
        inputs=[query_input],
        outputs=output,
    )


if __name__ == "__main__":
    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)