minor updates to results colnames, remove unwanted prints

app.py CHANGED

@@ -122,7 +122,7 @@ def infer_user_intent(query):
 
 # function to combine entities, intent and API call
 def construct_and_execute_api_call(query):
-    print("query:\n{}\n".format(query))
+    print("starting GDC-QAG on query:\n{}\n".format(query))
     # Infer entities
     initial_cancer_entities = utilities.return_initial_cancer_entities(
         query, model="en_ner_bc5cdr_md"

@@ -188,9 +188,9 @@ def generate_response(modified_query):
 
 def batch_test(query):
     modified_query = utilities.construct_modified_query_base_llm(query)
-    print(f"
+    print(f"Obtain baseline llama-3B response on modified query: {modified_query}")
     llama_base_output = generate_response(modified_query)
-    print(f"
+    print(f"llama-3B baseline response: {llama_base_output}")
     try:
         result = construct_and_execute_api_call(query)
     except Exception as e:

@@ -229,7 +229,7 @@ def get_prefinal_response(row):
     except Exception as e:
         print(f"unable to retrieve query: {query} or helper_output: {helper_output}")
     modified_query = utilities.construct_modified_query(query, helper_output)
-    print('generate LLM response')
+    print('generate LLM response on query augmented prompt')
     prefinal_llama_with_helper_output = generate_response(modified_query)
     return pd.Series([modified_query, prefinal_llama_with_helper_output])
 

@@ -237,12 +237,9 @@ def get_prefinal_response(row):
 def execute_pipeline(question: str):
     df = pd.DataFrame({"questions": [question]})
     print(f"Question received: {question}")
-    print("starting pipeline")
-    print("CUDA available:", torch.cuda.is_available())
     print("CUDA device name:", torch.cuda.get_device_name(0))
 
     # queries input file
-    print(f"running test on input {df}")
     df[
         [
             "llama_base_output",

@@ -273,12 +270,12 @@ def execute_pipeline(question: str):
         ]
     ] = df_exploded.apply(lambda x: utilities.postprocess_response(x), axis=1)
     final_columns = utilities.get_final_columns()
-    result = df_exploded[final_columns]
+    result = df_exploded[final_columns].copy()
     result.rename(
         columns={
             "llama_base_output": "llama-3B baseline output",
             "modified_prompt": "Query augmented prompt",
-            "helper_output": "
+            "helper_output": "GDC Result",
             "ground_truth_stat": "Ground truth frequency from GDC",
             "llama_base_stat": "llama-3B baseline frequency",
             "delta_llama": "llama-3B frequency - Ground truth frequency",

@@ -291,7 +288,7 @@ def execute_pipeline(question: str):
         },
         inplace=True,
     )
-    result.index = ["QAG
+    result.index = ["GDC-QAG results"] * len(result)
     print("completed")
     print("writing result string now")
 

@@ -300,10 +297,10 @@ def execute_pipeline(question: str):
 
     result_string = ""
 
-    result_string += f"Question: {result['QAG
+    result_string += f"Question: {result['GDC-QAG results']['Question']}\n"
-    result_string += f"llama-3B baseline output: {result['QAG
+    result_string += f"llama-3B baseline output: {result['GDC-QAG results']['llama-3B baseline frequency']}%\n"
-    result_string += f"Query augmented prompt: {result['QAG
+    result_string += f"Query augmented prompt: {result['GDC-QAG results']['Query augmented prompt']}"
-    result_string += f"Query augmented generation: {result['QAG
+    result_string += f"Query augmented generation: {result['GDC-QAG results']['Query augmented generation']}"
 
     return result_string
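
The entity step in the first hunk calls utilities.return_initial_cancer_entities(query, model="en_ner_bc5cdr_md"), whose body is not part of this diff. As a rough illustration only, a minimal stand-in using scispaCy's BC5CDR NER model might look like the sketch below; the function body and the DISEASE-only filter are assumptions, not the repository's actual helper.

import spacy  # requires scispacy and the en_ner_bc5cdr_md model to be installed

def return_initial_cancer_entities(query, model="en_ner_bc5cdr_md"):
    # Hypothetical stand-in for utilities.return_initial_cancer_entities:
    # the BC5CDR model tags DISEASE and CHEMICAL spans.
    nlp = spacy.load(model)
    doc = nlp(query)
    # Keep disease mentions; any cancer-specific filtering is assumed
    # to happen downstream.
    return [ent.text for ent in doc.ents if ent.label_ == "DISEASE"]

print(return_initial_cancer_entities(
    "What fraction of lung adenocarcinoma cases carry KRAS mutations?"
))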
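The rename/index changes in execute_pipeline follow a common pandas pattern: subset with .copy() so later in-place edits do not raise SettingWithCopyWarning, relabel the columns, and tag every row with a fixed index label. Below is a minimal sketch with made-up column names and values; note that the result['GDC-QAG results'][...] lookups only resolve if the frame is transposed first, which presumably happens in the lines elided from this diff.

import pandas as pd

# Made-up stand-ins for df_exploded and utilities.get_final_columns()
df_exploded = pd.DataFrame(
    {"Question": ["example question"], "llama_base_stat": [12.5], "unused": [0]}
)
final_columns = ["Question", "llama_base_stat"]

# .copy() detaches the subset from df_exploded, so the in-place rename
# below does not trigger pandas' SettingWithCopyWarning
result = df_exploded[final_columns].copy()
result.rename(
    columns={"llama_base_stat": "llama-3B baseline frequency"}, inplace=True
)
result.index = ["GDC-QAG results"] * len(result)

# Assumed transpose: turns the "GDC-QAG results" row label into a column,
# so each renamed field can be read off as result['GDC-QAG results'][field]
result = result.T
result_string = f"Question: {result['GDC-QAG results']['Question']}\n"
result_string += (
    f"llama-3B baseline frequency: "
    f"{result['GDC-QAG results']['llama-3B baseline frequency']}%\n"
)
print(result_string)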