Spaces:
Sleeping
Sleeping
Move the get_project_embeddings() call from module level into postprocess_cancer_entities() and pass the results to get_top_k_cancer_entities()
Browse files
app.py
CHANGED
|
@@ -131,15 +131,6 @@ def get_project_embeddings():
|
|
| 131 |
return project_rows, row_embeddings
|
| 132 |
|
| 133 |
|
| 134 |
-
# get project embeddings
|
| 135 |
-
print('loading cancer embeddings')
|
| 136 |
-
project_rows, row_embeddings = get_project_embeddings()
|
| 137 |
-
print(f"row_embeddings: {row_embeddings}")
|
| 138 |
-
print(f"row embeddings device {row_embeddings.device}")
|
| 139 |
-
print(f"project rows: {project_rows}")
|
| 140 |
-
print(f"project rows device {project_rows.device}")
|
| 141 |
-
|
| 142 |
-
|
| 143 |
|
| 144 |
def check_if_project_id_in_query(query):
|
| 145 |
# check if mention of project keys
|
|
@@ -176,7 +167,7 @@ def proj_id_and_partial_match(query, initial_cancer_entities):
|
|
| 176 |
|
| 177 |
|
| 178 |
@spaces.GPU(duration=15)
|
| 179 |
-
def get_top_k_cancer_entities(query, top_k=20):
|
| 180 |
top_cancer_entities = []
|
| 181 |
query_embedding = st_model.encode(query, convert_to_tensor=True, device='cuda')
|
| 182 |
print(f"query embedding is on device: {query_embedding.device}")
|
|
@@ -203,6 +194,13 @@ def get_top_k_cancer_entities(query, top_k=20):
|
|
| 203 |
@utilities.timeit
|
| 204 |
def postprocess_cancer_entities(initial_cancer_entities, query):
|
| 205 |
# print('initial cancer entities {}'.format(initial_cancer_entities))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
final_entities = check_if_project_id_in_query(query)
|
| 207 |
if final_entities:
|
| 208 |
return final_entities
|
|
@@ -226,7 +224,7 @@ def postprocess_cancer_entities(initial_cancer_entities, query):
|
|
| 226 |
if not final_entities:
|
| 227 |
print('Test embedding based match')
|
| 228 |
for i in initial_cancer_entities:
|
| 229 |
-
c_entities = get_top_k_cancer_entities(i)
|
| 230 |
final_entities.append(c_entities)
|
| 231 |
final_entities = list(chain.from_iterable(final_entities))
|
| 232 |
else:
|
|
|
|
| 131 |
return project_rows, row_embeddings
|
| 132 |
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
def check_if_project_id_in_query(query):
|
| 136 |
# check if mention of project keys
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
@spaces.GPU(duration=15)
|
| 170 |
+
def get_top_k_cancer_entities(query, project_rows, row_embeddings, top_k=20):
|
| 171 |
top_cancer_entities = []
|
| 172 |
query_embedding = st_model.encode(query, convert_to_tensor=True, device='cuda')
|
| 173 |
print(f"query embedding is on device: {query_embedding.device}")
|
|
|
|
| 194 |
@utilities.timeit
|
| 195 |
def postprocess_cancer_entities(initial_cancer_entities, query):
|
| 196 |
# print('initial cancer entities {}'.format(initial_cancer_entities))
|
| 197 |
+
# get project embeddings
|
| 198 |
+
print('loading cancer embeddings')
|
| 199 |
+
project_rows, row_embeddings = get_project_embeddings()
|
| 200 |
+
print(f"row_embeddings: {row_embeddings}")
|
| 201 |
+
print(f"row embeddings device {row_embeddings.device}")
|
| 202 |
+
print(f"project rows: {project_rows}")
|
| 203 |
+
print(f"project rows device {project_rows.device}")
|
| 204 |
final_entities = check_if_project_id_in_query(query)
|
| 205 |
if final_entities:
|
| 206 |
return final_entities
|
|
|
|
| 224 |
if not final_entities:
|
| 225 |
print('Test embedding based match')
|
| 226 |
for i in initial_cancer_entities:
|
| 227 |
+
c_entities = get_top_k_cancer_entities(i, project_rows, row_embeddings)
|
| 228 |
final_entities.append(c_entities)
|
| 229 |
final_entities = list(chain.from_iterable(final_entities))
|
| 230 |
else:
|