aatu18 committed on
Commit
968854a
·
verified ·
1 Parent(s): 47785b8

Move the get_project_embeddings() call inside postprocess_cancer_entities and pass the embeddings to get_top_k_cancer_entities

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -131,15 +131,6 @@ def get_project_embeddings():
131
  return project_rows, row_embeddings
132
 
133
 
134
- # get project embeddings
135
- print('loading cancer embeddings')
136
- project_rows, row_embeddings = get_project_embeddings()
137
- print(f"row_embeddings: {row_embeddings}")
138
- print(f"row embeddings device {row_embeddings.device}")
139
- print(f"project rows: {project_rows}")
140
- print(f"project rows device {project_rows.device}")
141
-
142
-
143
 
144
  def check_if_project_id_in_query(query):
145
  # check if mention of project keys
@@ -176,7 +167,7 @@ def proj_id_and_partial_match(query, initial_cancer_entities):
176
 
177
 
178
  @spaces.GPU(duration=15)
179
- def get_top_k_cancer_entities(query, top_k=20):
180
  top_cancer_entities = []
181
  query_embedding = st_model.encode(query, convert_to_tensor=True, device='cuda')
182
  print(f"query embedding is on device: {query_embedding.device}")
@@ -203,6 +194,13 @@ def get_top_k_cancer_entities(query, top_k=20):
203
  @utilities.timeit
204
  def postprocess_cancer_entities(initial_cancer_entities, query):
205
  # print('initial cancer entities {}'.format(initial_cancer_entities))
 
 
 
 
 
 
 
206
  final_entities = check_if_project_id_in_query(query)
207
  if final_entities:
208
  return final_entities
@@ -226,7 +224,7 @@ def postprocess_cancer_entities(initial_cancer_entities, query):
226
  if not final_entities:
227
  print('Test embedding based match')
228
  for i in initial_cancer_entities:
229
- c_entities = get_top_k_cancer_entities(i)
230
  final_entities.append(c_entities)
231
  final_entities = list(chain.from_iterable(final_entities))
232
  else:
 
131
  return project_rows, row_embeddings
132
 
133
 
 
 
 
 
 
 
 
 
 
134
 
135
  def check_if_project_id_in_query(query):
136
  # check if mention of project keys
 
167
 
168
 
169
  @spaces.GPU(duration=15)
170
+ def get_top_k_cancer_entities(query, project_rows, row_embeddings, top_k=20):
171
  top_cancer_entities = []
172
  query_embedding = st_model.encode(query, convert_to_tensor=True, device='cuda')
173
  print(f"query embedding is on device: {query_embedding.device}")
 
194
  @utilities.timeit
195
  def postprocess_cancer_entities(initial_cancer_entities, query):
196
  # print('initial cancer entities {}'.format(initial_cancer_entities))
197
+ # get project embeddings
198
+ print('loading cancer embeddings')
199
+ project_rows, row_embeddings = get_project_embeddings()
200
+ print(f"row_embeddings: {row_embeddings}")
201
+ print(f"row embeddings device {row_embeddings.device}")
202
+ print(f"project rows: {project_rows}")
203
+ print(f"project rows device {project_rows.device}")
204
  final_entities = check_if_project_id_in_query(query)
205
  if final_entities:
206
  return final_entities
 
224
  if not final_entities:
225
  print('Test embedding based match')
226
  for i in initial_cancer_entities:
227
+ c_entities = get_top_k_cancer_entities(i, project_rows, row_embeddings)
228
  final_entities.append(c_entities)
229
  final_entities = list(chain.from_iterable(final_entities))
230
  else: