aatu18 committed on
Commit
6894046
·
verified ·
1 Parent(s): ed05292

Embeddings-based project match; other optimizations

Browse files
Files changed (1) hide show
  1. app.py +63 -13
app.py CHANGED
@@ -4,6 +4,7 @@ import re
4
  from types import SimpleNamespace
5
 
6
  import gradio as gr
 
7
  import pandas as pd
8
  import spaces
9
  import spacy
@@ -18,6 +19,7 @@ from transformers import (
18
  BertTokenizer,
19
  set_seed,
20
  )
 
21
 
22
  from methods import gdc_api_calls, utilities
23
 
@@ -74,6 +76,10 @@ intent_tok = AutoTokenizer.from_pretrained(
74
  intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
75
  intent_model = intent_model.to("cuda").eval()
76
 
 
 
 
 
77
 
78
  print("loading gdc genes and mutations")
79
  gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
@@ -112,6 +118,17 @@ def infer_gene_entities_from_query(query):
112
 
113
 
114
 
 
 
 
 
 
 
 
 
 
 
 
115
  def check_if_project_id_in_query(query):
116
  # check if mention of project keys
117
  # e.g. TCGA-BRCA in query
@@ -147,11 +164,31 @@ def proj_id_and_partial_match(query, initial_cancer_entities):
147
 
148
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  @utilities.timeit
152
  def postprocess_cancer_entities(initial_cancer_entities, query):
153
  # print('initial cancer entities {}'.format(initial_cancer_entities))
154
- # print('check if GDC project-id mentioned in query')
155
  final_entities = check_if_project_id_in_query(query)
156
  if final_entities:
157
  return final_entities
@@ -171,6 +208,13 @@ def postprocess_cancer_entities(initial_cancer_entities, query):
171
  final_entities = proj_id_and_partial_match(
172
  query, initial_cancer_entities
173
  )
 
 
 
 
 
 
 
174
  else:
175
  # no initial_cancer_entities
176
  # check project_mappings keys/values for matches with query terms
@@ -400,28 +444,20 @@ def batch_test(query):
400
  def get_prefinal_response(row):
401
  try:
402
  query = row["questions"]
403
- genes = ','.join(row['gene_entities'])
404
  gdc_result = row["gdc_result"]
405
  except Exception as e:
406
  print(f"unable to retrieve query: {query} or gdc_result: {gdc_result}")
407
 
408
- intent = intent_expansion[row['intent']]
409
-
410
- print("\nStep 6: Construct LLM prompts for llama-3B\n")
411
- descriptive_prompt = construct_modified_query_description(genes, intent)
412
  percentage_prompt = construct_modified_query_percentage(query, gdc_result)
413
 
414
- print("\nStep 7: Generate LLM response R on query augmented prompts\n")
415
- descriptive_response = generate_descriptive_response(descriptive_prompt)
416
- if not descriptive_response.endswith('.'):
417
- descriptive_response += '.'
418
 
419
  percentage_response = generate_percentage_response(percentage_prompt)
420
  percentage_response = re.sub(
421
  r'final response', 'frequency for your query', percentage_response)
422
  return pd.Series([
423
- descriptive_prompt, percentage_prompt,
424
- descriptive_response, percentage_response
425
  ])
426
 
427
 
@@ -630,7 +666,21 @@ def execute_pipeline(question: str):
630
  ]
631
  ] = df["questions"].apply(lambda x: batch_test(x))
632
  df_exploded = df.explode("gdc_result", ignore_index=True)
633
- df_exploded[["descriptive_prompt", "percentage_prompt", "descriptive_response", "percentage_response"]] = df_exploded.apply(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  lambda x: get_prefinal_response(x), axis=1)
635
 
636
 
 
4
  from types import SimpleNamespace
5
 
6
  import gradio as gr
7
+ from itertools import chain
8
  import pandas as pd
9
  import spaces
10
  import spacy
 
19
  BertTokenizer,
20
  set_seed,
21
  )
22
+ from sentence_transformers import SentenceTransformer, util
23
 
24
  from methods import gdc_api_calls, utilities
25
 
 
76
  intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
77
  intent_model = intent_model.to("cuda").eval()
78
 
79
+ # load sentence transformer model to test cancer embeddings
80
+ st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
81
+ st_model = st_model.to("cuda")
82
+
83
 
84
  print("loading gdc genes and mutations")
85
  gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
 
118
 
119
 
120
 
121
def get_project_embeddings(mappings=None, encoder=None):
    """Build one comma-joined text row per GDC project and embed each row.

    Args:
        mappings: dict of project_id -> list of synonym strings; defaults to
            the module-level ``project_mappings``.
        encoder: object exposing ``encode(texts, convert_to_tensor=True)``;
            defaults to the module-level ``st_model`` sentence transformer.

    Returns:
        (project_rows, row_embeddings): the list of "PROJECT,syn1,syn2,..."
        strings and the corresponding embedding tensor (one per row).
    """
    if mappings is None:
        mappings = project_mappings
    if encoder is None:
        # BUG FIX: the original called ``model.encode`` here, but queries are
        # embedded with ``st_model`` in get_top_k_cancer_entities — both sides
        # of the cosine similarity must come from the same embedding model.
        encoder = st_model
    project_rows = []
    for project_id, synonyms in mappings.items():
        # Strip commas from synonyms so the comma-joined row can later be
        # split back into its fields unambiguously.
        cleaned = [synonym.replace(',', '') for synonym in synonyms]
        project_rows.append(','.join([project_id] + cleaned))
    row_embeddings = encoder.encode(project_rows, convert_to_tensor=True)
    return project_rows, row_embeddings
129
+
130
+
131
+
132
  def check_if_project_id_in_query(query):
133
  # check if mention of project keys
134
  # e.g. TCGA-BRCA in query
 
164
 
165
 
166
 
167
def get_top_k_cancer_entities(query, row_embeddings, project_rows, top_k=20):
    """Return GDC project ids whose embedded description matches *query*.

    Args:
        query: free-text cancer entity string to match.
        row_embeddings: tensor of per-project row embeddings (see
            ``get_project_embeddings``).
        project_rows: parallel list of "PROJECT,syn1,..." strings.
        top_k: maximum number of candidate rows to inspect.

    Returns:
        List of project ids (the leading comma-separated field of each
        qualifying row) whose cosine similarity to the query exceeds 0.5.
    """
    top_cancer_entities = []
    query_embedding = st_model.encode(query, convert_to_tensor=True)
    # cos_sim returns shape (1, num_rows); topk operates on the last dim.
    cosine_scores = util.cos_sim(query_embedding, row_embeddings)
    # BUG FIX: torch.topk raises when k exceeds the number of candidates,
    # so clamp top_k to the number of available project rows.
    k = min(top_k, len(project_rows))
    top_results = torch.topk(cosine_scores, k=k)
    top_results_indices = top_results.indices.tolist()
    top_results_scores = top_results.values.tolist()
    print(top_results_scores)
    for idx, score in enumerate(top_results_scores[0]):
        # 0.5 similarity threshold: below this, matches are treated as noise.
        if score > 0.5:
            row_idx = top_results_indices[0][idx]
            print('best row, score: {} {}'.format(project_rows[row_idx], score))
            top_cancer_entities.append([project_rows[row_idx], score])
    # The project id is the first comma-separated field of each matched row.
    # (Original wrapped this in a broad try/except returning []; the
    # comprehension over [str, float] pairs cannot fail, so it was removed.)
    top_projects = [entry[0].split(',')[0] for entry in top_cancer_entities]
    return top_projects
185
+
186
+
187
 
188
  @utilities.timeit
189
  def postprocess_cancer_entities(initial_cancer_entities, query):
190
  # print('initial cancer entities {}'.format(initial_cancer_entities))
191
+ project_rows, row_embeddings = get_project_embeddings()
192
  final_entities = check_if_project_id_in_query(query)
193
  if final_entities:
194
  return final_entities
 
208
  final_entities = proj_id_and_partial_match(
209
  query, initial_cancer_entities
210
  )
211
+ # try embedding based match
212
+ if not final_entities:
213
+ print('Test embedding based match')
214
+ for i in initial_cancer_entities:
215
+ c_entities = get_top_k_cancer_entities(i, row_embeddings, project_rows)
216
+ final_entities.append(c_entities)
217
+ final_entities = list(chain.from_iterable(final_entities))
218
  else:
219
  # no initial_cancer_entities
220
  # check project_mappings keys/values for matches with query terms
 
444
def get_prefinal_response(row):
    """Build the percentage prompt for one exploded result row and generate
    the corresponding LLM response.

    Args:
        row: DataFrame row with at least "questions" and "gdc_result" fields.

    Returns:
        pd.Series of [percentage_prompt, percentage_response].

    Raises:
        KeyError: if a required field is missing from *row*.
    """
    try:
        query = row["questions"]
        gdc_result = row["gdc_result"]
    except KeyError as e:
        # BUG FIX: the original handler printed ``query``/``gdc_result``,
        # which are unbound when the lookup itself failed (raising NameError
        # inside the handler), and then fell through to use them anyway.
        # Report the missing field and re-raise instead.
        print(f"unable to retrieve required field from row: {e}")
        raise

    print("\nStep 6: Construct LLM prompts (percentage) for llama-3B\n")
    percentage_prompt = construct_modified_query_percentage(query, gdc_result)

    print("\nStep 7: Generate LLM response R (percentage) on query augmented prompts\n")
    percentage_response = generate_percentage_response(percentage_prompt)
    # Rephrase the model's boilerplate lead-in for the end user.
    percentage_response = re.sub(
        r'final response', 'frequency for your query', percentage_response)
    return pd.Series([
        percentage_prompt, percentage_response
    ])
462
 
463
 
 
666
  ]
667
  ] = df["questions"].apply(lambda x: batch_test(x))
668
  df_exploded = df.explode("gdc_result", ignore_index=True)
669
+
670
+ # generate descriptive response once based on genes and intent
671
+ print("\nStep 6: Construct LLM prompts (descriptive) for llama-3B\n")
672
+ intent = intent_expansion[df['intent'].iloc[0]]
673
+ genes = ','.join(df['gene_entities'].iloc[0])
674
+ descriptive_prompt = construct_modified_query_description(genes, intent)
675
+
676
+ print("\nStep 7: Generate LLM response R (descriptive) on query augmented prompts\n")
677
+ descriptive_response = generate_descriptive_response(descriptive_prompt, model, tok)
678
+ if not descriptive_response.endswith('.'):
679
+ descriptive_response += '.'
680
+
681
+ df_exploded[['descriptive_prompt', 'descriptive_response']] = descriptive_prompt, descriptive_response
682
+
683
+ df_exploded[["percentage_prompt", "percentage_response"]] = df_exploded.apply(
684
  lambda x: get_prefinal_response(x), axis=1)
685
 
686