Spaces:
Sleeping
Sleeping
embeddings based project match, other optims
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import re
|
|
| 4 |
from types import SimpleNamespace
|
| 5 |
|
| 6 |
import gradio as gr
|
|
|
|
| 7 |
import pandas as pd
|
| 8 |
import spaces
|
| 9 |
import spacy
|
|
@@ -18,6 +19,7 @@ from transformers import (
|
|
| 18 |
BertTokenizer,
|
| 19 |
set_seed,
|
| 20 |
)
|
|
|
|
| 21 |
|
| 22 |
from methods import gdc_api_calls, utilities
|
| 23 |
|
|
@@ -74,6 +76,10 @@ intent_tok = AutoTokenizer.from_pretrained(
|
|
| 74 |
intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
|
| 75 |
intent_model = intent_model.to("cuda").eval()
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
print("loading gdc genes and mutations")
|
| 79 |
gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
|
|
@@ -112,6 +118,17 @@ def infer_gene_entities_from_query(query):
|
|
| 112 |
|
| 113 |
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def check_if_project_id_in_query(query):
|
| 116 |
# check if mention of project keys
|
| 117 |
# e.g. TCGA-BRCA in query
|
|
@@ -147,11 +164,31 @@ def proj_id_and_partial_match(query, initial_cancer_entities):
|
|
| 147 |
|
| 148 |
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
@utilities.timeit
|
| 152 |
def postprocess_cancer_entities(initial_cancer_entities, query):
|
| 153 |
# print('initial cancer entities {}'.format(initial_cancer_entities))
|
| 154 |
-
|
| 155 |
final_entities = check_if_project_id_in_query(query)
|
| 156 |
if final_entities:
|
| 157 |
return final_entities
|
|
@@ -171,6 +208,13 @@ def postprocess_cancer_entities(initial_cancer_entities, query):
|
|
| 171 |
final_entities = proj_id_and_partial_match(
|
| 172 |
query, initial_cancer_entities
|
| 173 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
else:
|
| 175 |
# no initial_cancer_entities
|
| 176 |
# check project_mappings keys/values for matches with query terms
|
|
@@ -400,28 +444,20 @@ def batch_test(query):
|
|
| 400 |
def get_prefinal_response(row):
|
| 401 |
try:
|
| 402 |
query = row["questions"]
|
| 403 |
-
genes = ','.join(row['gene_entities'])
|
| 404 |
gdc_result = row["gdc_result"]
|
| 405 |
except Exception as e:
|
| 406 |
print(f"unable to retrieve query: {query} or gdc_result: {gdc_result}")
|
| 407 |
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
print("\nStep 6: Construct LLM prompts for llama-3B\n")
|
| 411 |
-
descriptive_prompt = construct_modified_query_description(genes, intent)
|
| 412 |
percentage_prompt = construct_modified_query_percentage(query, gdc_result)
|
| 413 |
|
| 414 |
-
print("\nStep 7: Generate LLM response R on query augmented prompts\n")
|
| 415 |
-
descriptive_response = generate_descriptive_response(descriptive_prompt)
|
| 416 |
-
if not descriptive_response.endswith('.'):
|
| 417 |
-
descriptive_response += '.'
|
| 418 |
|
| 419 |
percentage_response = generate_percentage_response(percentage_prompt)
|
| 420 |
percentage_response = re.sub(
|
| 421 |
r'final response', 'frequency for your query', percentage_response)
|
| 422 |
return pd.Series([
|
| 423 |
-
|
| 424 |
-
descriptive_response, percentage_response
|
| 425 |
])
|
| 426 |
|
| 427 |
|
|
@@ -630,7 +666,21 @@ def execute_pipeline(question: str):
|
|
| 630 |
]
|
| 631 |
] = df["questions"].apply(lambda x: batch_test(x))
|
| 632 |
df_exploded = df.explode("gdc_result", ignore_index=True)
|
| 633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 634 |
lambda x: get_prefinal_response(x), axis=1)
|
| 635 |
|
| 636 |
|
|
|
|
| 4 |
from types import SimpleNamespace
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
+
from itertools import chain
|
| 8 |
import pandas as pd
|
| 9 |
import spaces
|
| 10 |
import spacy
|
|
|
|
| 19 |
BertTokenizer,
|
| 20 |
set_seed,
|
| 21 |
)
|
| 22 |
+
from sentence_transformers import SentenceTransformer, util
|
| 23 |
|
| 24 |
from methods import gdc_api_calls, utilities
|
| 25 |
|
|
|
|
| 76 |
intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
|
| 77 |
intent_model = intent_model.to("cuda").eval()
|
| 78 |
|
| 79 |
+
# load sentence transformer model to test cancer embeddings
|
| 80 |
+
st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 81 |
+
st_model = st_model.to("cuda")
|
| 82 |
+
|
| 83 |
|
| 84 |
print("loading gdc genes and mutations")
|
| 85 |
gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
|
| 121 |
+
def get_project_embeddings():
    """Build one text row per project mapping entry and embed all rows.

    Each row is the mapping key followed by its mapped values, joined with
    commas. Commas inside the individual values are stripped first, so the
    text before the first comma of a row is always the mapping key.

    Returns:
        tuple: ``(project_rows, row_embeddings)`` where ``project_rows`` is
        the list of row strings and ``row_embeddings`` is the tensor
        returned by ``st_model.encode`` for those rows.
    """
    project_rows = [
        ','.join([k] + [item.replace(',', '') for item in v])
        for k, v in project_mappings.items()
    ]
    # Encode with the sentence-transformer (st_model), the same model used
    # to embed the query, so that cosine similarity between the two is
    # meaningful. The previous code called ``model.encode`` here.
    row_embeddings = st_model.encode(project_rows, convert_to_tensor=True)
    return project_rows, row_embeddings
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
| 132 |
def check_if_project_id_in_query(query):
|
| 133 |
# check if mention of project keys
|
| 134 |
# e.g. TCGA-BRCA in query
|
|
|
|
| 164 |
|
| 165 |
|
| 166 |
|
| 167 |
+
def get_top_k_cancer_entities(query, row_embeddings, project_rows, top_k=20,
                              min_score=0.5):
    """Return the leading field of the rows most similar to ``query``.

    The query is embedded with ``st_model`` and compared against
    ``row_embeddings`` by cosine similarity. Of the ``top_k`` best rows,
    those scoring strictly above ``min_score`` are kept, and the text before
    the first comma of each kept row is returned.

    Args:
        query: text to embed and match.
        row_embeddings: embeddings of ``project_rows`` (same order).
        project_rows: comma-joined row strings.
        top_k: maximum number of candidate rows to consider.
        min_score: cosine-similarity cutoff (exclusive); defaults to the
            previously hard-coded 0.5.

    Returns:
        list: leading comma-separated field of each retained row, in the
        order returned by ``torch.topk``; empty if nothing qualifies.
    """
    if not project_rows:
        # Nothing to compare against.
        return []

    query_embedding = st_model.encode(query, convert_to_tensor=True)
    cosine_scores = util.cos_sim(query_embedding, row_embeddings)

    # torch.topk raises when k exceeds the size of the last dimension, so
    # clamp k to the number of rows actually scored.
    k = min(top_k, cosine_scores.shape[-1])
    top_results = torch.topk(cosine_scores, k=k)
    top_results_indices = top_results.indices.tolist()
    top_results_scores = top_results.values.tolist()
    print(top_results_scores)

    top_projects = []
    # A single query yields a one-row score matrix, hence the [0] index.
    for row_idx, score in zip(top_results_indices[0], top_results_scores[0]):
        if score > min_score:
            best_row = project_rows[row_idx]
            print('best row, score: {} {}'.format(best_row, score))
            top_projects.append(best_row.split(',')[0])
    return top_projects
|
| 185 |
+
|
| 186 |
+
|
| 187 |
|
| 188 |
@utilities.timeit
|
| 189 |
def postprocess_cancer_entities(initial_cancer_entities, query):
|
| 190 |
# print('initial cancer entities {}'.format(initial_cancer_entities))
|
| 191 |
+
project_rows, row_embeddings = get_project_embeddings()
|
| 192 |
final_entities = check_if_project_id_in_query(query)
|
| 193 |
if final_entities:
|
| 194 |
return final_entities
|
|
|
|
| 208 |
final_entities = proj_id_and_partial_match(
|
| 209 |
query, initial_cancer_entities
|
| 210 |
)
|
| 211 |
+
# try embedding based match
|
| 212 |
+
if not final_entities:
|
| 213 |
+
print('Test embedding based match')
|
| 214 |
+
for i in initial_cancer_entities:
|
| 215 |
+
c_entities = get_top_k_cancer_entities(i, row_embeddings, project_rows)
|
| 216 |
+
final_entities.append(c_entities)
|
| 217 |
+
final_entities = list(chain.from_iterable(final_entities))
|
| 218 |
else:
|
| 219 |
# no initial_cancer_entities
|
| 220 |
# check project_mappings keys/values for matches with query terms
|
|
|
|
| 444 |
def get_prefinal_response(row):
    """Build the percentage prompt for one row and generate its response.

    Args:
        row: mapping-like row providing ``"questions"`` and
            ``"gdc_result"`` entries.

    Returns:
        pd.Series: ``[percentage_prompt, percentage_response]``.

    Raises:
        KeyError: if ``"questions"`` or ``"gdc_result"`` is missing from
            ``row``; the diagnostic message is printed first.
    """
    # Pre-initialise so the diagnostic below can always be formatted; the
    # earlier code referenced these names in the handler while they could
    # still be unbound, which masked the real error.
    query = None
    gdc_result = None
    try:
        query = row["questions"]
        gdc_result = row["gdc_result"]
    except KeyError:
        print(f"unable to retrieve query: {query} or gdc_result: {gdc_result}")
        # The prompt cannot be built without both values.
        raise

    print("\nStep 6: Construct LLM prompts (percentage) for llama-3B\n")
    percentage_prompt = construct_modified_query_percentage(query, gdc_result)

    print("\nStep 7: Generate LLM response R (percentage) on query augmented prompts\n")

    percentage_response = generate_percentage_response(percentage_prompt)
    percentage_response = re.sub(
        r'final response', 'frequency for your query', percentage_response)
    return pd.Series([
        percentage_prompt, percentage_response
    ])
|
| 462 |
|
| 463 |
|
|
|
|
| 666 |
]
|
| 667 |
] = df["questions"].apply(lambda x: batch_test(x))
|
| 668 |
df_exploded = df.explode("gdc_result", ignore_index=True)
|
| 669 |
+
|
| 670 |
+
# generate descriptive response once based on genes and intent
|
| 671 |
+
print("\nStep 6: Construct LLM prompts (descriptive) for llama-3B\n")
|
| 672 |
+
intent = intent_expansion[df['intent'].iloc[0]]
|
| 673 |
+
genes = ','.join(df['gene_entities'].iloc[0])
|
| 674 |
+
descriptive_prompt = construct_modified_query_description(genes, intent)
|
| 675 |
+
|
| 676 |
+
print("\nStep 7: Generate LLM response R (descriptive) on query augmented prompts\n")
|
| 677 |
+
descriptive_response = generate_descriptive_response(descriptive_prompt, model, tok)
|
| 678 |
+
if not descriptive_response.endswith('.'):
|
| 679 |
+
descriptive_response += '.'
|
| 680 |
+
|
| 681 |
+
df_exploded[['descriptive_prompt', 'descriptive_response']] = descriptive_prompt, descriptive_response
|
| 682 |
+
|
| 683 |
+
df_exploded[["percentage_prompt", "percentage_response"]] = df_exploded.apply(
|
| 684 |
lambda x: get_prefinal_response(x), axis=1)
|
| 685 |
|
| 686 |
|