token-fix (#3)
update tokens for app (c4634a054c973948c282bc181693556957d52580)
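The commit replaces the separate per-model secrets (fineTest, query_intent_test, working_llama_token) with a single hf_TOKEN read from the Space's environment. A minimal sketch of the pattern, using the names from the diff; Hugging Face Spaces expose secrets as environment variables, and the False fallback means downloads proceed anonymously when the secret is unset:

    import os

    # One shared Hugging Face access token for all gated downloads.
    # Falls back to False (anonymous access) if the secret is not set.
    hf_TOKEN = os.environ.get("hf_svc_ctds", False)

    # The same token is then passed to every from_pretrained call, e.g.
    # AutoTokenizer.from_pretrained(model_id, token=hf_TOKEN)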
app.py CHANGED
@@ -1,23 +1,25 @@
+import json
 import os
 from types import SimpleNamespace
+
 import gradio as gr
-import json
 import pandas as pd
 import spaces
 import torch
-from methods import gdc_api_calls, utilities
-from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification
 from guidance import gen as guidance_gen
 from guidance.models import Transformers
-from transformers import set_seed
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BertForSequenceClassification,
+    BertTokenizer,
+    set_seed,
+)
 
 from methods import gdc_api_calls, utilities
 
-
 # set up various tokens
-
-hf_TOKEN = os.environ.get("fineTest", False)
-intent_token = os.environ.get("query_intent_test", False)
+hf_TOKEN = os.environ.get("hf_svc_ctds", False)
 
 EXAMPLE_INPUTS = [
     "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?",
@@ -26,7 +28,7 @@ EXAMPLE_INPUTS = [
     "What fraction of cases have simple somatic mutations or copy number variants in ALK in Uterine Carcinosarcoma TCGA-UCS project in the genomic data commons?",
     "How often is microsatellite instability observed in Stomach Adenocarcinoma TCGA-STAD project in the genomic data commons?",
     "How often is the BRAF V600E mutation found in Skin Cutaneous Melanoma TCGA-SKCM project in the genomic data commons?",
-    "What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?"
+    "What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?",
 ]
 
 EXAMPLE_LABELS = [
@@ -36,22 +38,20 @@ EXAMPLE_LABELS = [
     "copy number variants or somatic mutations",
     "microsatellite-instability",
     "simple somatic mutation",
-    "combination somatic mutations"
+    "combination somatic mutations",
 ]
 
 # set up requirements: models and data
 print("getting gdc project information")
 project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)
 
-print(
-model_id =
+print("loading intent model and tokenizer")
+model_id = "uc-ctds/query_intent"
 intent_tok = AutoTokenizer.from_pretrained(
-    model_id, trust_remote_code=True,
-    token=intent_token
+    model_id, trust_remote_code=True, token=hf_TOKEN
 )
-intent_model = BertForSequenceClassification.from_pretrained(
-
-intent_model = intent_model.to('cuda').eval()
+intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
+intent_model = intent_model.to("cuda").eval()
 
 
 print("loading gdc genes and mutations")
@@ -59,27 +59,15 @@ gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)
 
 print("loading llama-3B model and tokenizer")
 model_id = "meta-llama/Llama-3.2-3B-Instruct"
-tok = AutoTokenizer.from_pretrained(
-    model_id, trust_remote_code=True,
-    token=working_llama_token
-)
+tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_TOKEN)
 model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.float16,
-    trust_remote_code=True,
-    token=working_llama_token
+    model_id, torch_dtype=torch.float16, trust_remote_code=True, token=hf_TOKEN
 )
-model = model.to('cuda').eval()
+model = model.to("cuda").eval()
 
 
 # execute_api_call
-def execute_api_call(
-    intent,
-    gene_entities,
-    mutation_entities,
-    cancer_entities,
-    query
-):
+def execute_api_call(intent, gene_entities, mutation_entities, cancer_entities, query):
     if intent == "ssm_frequency":
         result, cancer_entities = utilities.get_ssm_frequency(
             gene_entities, mutation_entities, cancer_entities, project_mappings
@@ -123,7 +111,7 @@ def infer_user_intent(query):
         "cnv_and_ssm": 4.0,
     }
     inputs = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
-    inputs = {k: v.to('cuda') for k, v in inputs.items()}
+    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    outputs = intent_model(**inputs)
     probs = torch.nn.functional.softmax(outputs.logits, dim=1)
     predicted_label = torch.argmax(probs, dim=1).item()
@@ -172,11 +160,7 @@ def construct_and_execute_api_call(query):
     print("user intent:\n{}\n".format(intent))
     try:
         api_call_result, cancer_entities = execute_api_call(
-            intent,
-            gene_entities,
-            mutation_entities,
-            cancer_entities,
-            query
+            intent, gene_entities, mutation_entities, cancer_entities, query
         )
         print("api_call_result {}".format(api_call_result))
     except Exception as e:
@@ -195,17 +179,11 @@ def construct_and_execute_api_call(query):
 # generate llama model response
 @spaces.GPU(duration=30)
 def generate_response(modified_query):
-    #set_seed(1042)
+    # set_seed(1042)
     regex = "The final answer is: \d*\.\d*%"
     lm = Transformers(model=model, tokenizer=tok)
     lm += modified_query
-    lm += guidance_gen(
-        "gen_response",
-        n=1,
-        temperature=0,
-        max_tokens=1000,
-        regex=regex
-    )
+    lm += guidance_gen("gen_response", n=1, temperature=0, max_tokens=1000, regex=regex)
     return lm["gen_response"]
 
 
@@ -256,10 +234,9 @@ def get_prefinal_response(row):
     return pd.Series([modified_query, prefinal_llama_with_helper_output])
 
 
-
 def execute_pipeline(question: str):
-    df = pd.DataFrame({
-    print(f
+    df = pd.DataFrame({"questions": [question]})
+    print(f"Question received: {question}")
     print("starting pipeline")
     print("CUDA available:", torch.cuda.is_available())
     print("CUDA device name:", torch.cuda.get_device_name(0))
@@ -278,9 +255,7 @@ def execute_pipeline(question: str):
     ] = df["questions"].apply(lambda x: batch_test(x))
     df_exploded = df.explode("helper_output", ignore_index=True)
     df_exploded[["modified_prompt", "pre_final_llama_with_helper_output"]] = (
-        df_exploded.apply(
-            lambda x: get_prefinal_response(x), axis=1
-        )
+        df_exploded.apply(lambda x: get_prefinal_response(x), axis=1)
     )
     ### postprocess response
     print("postprocessing response")
@@ -296,31 +271,32 @@ def execute_pipeline(question: str):
             "delta_final",
             "final_response",
         ]
-    ] = df_exploded.apply(
-        lambda x: utilities.postprocess_response(x), axis=1
-    )
+    ] = df_exploded.apply(lambda x: utilities.postprocess_response(x), axis=1)
     final_columns = utilities.get_final_columns()
     result = df_exploded[final_columns]
-    result.rename(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    result.rename(
+        columns={
+            "llama_base_output": "llama-3B baseline output",
+            "modified_prompt": "Query augmented prompt",
+            "helper_output": "Processed GDC API result",
+            "ground_truth_stat": "Ground truth frequency from GDC",
+            "llama_base_stat": "llama-3B baseline frequency",
+            "delta_llama": "llama-3B frequency - Ground truth frequency",
+            "final_response": "Query augmented generation",
+            "intent": "Intent",
+            "cancer_entities": "Cancer entities",
+            "gene_entities": "Gene entities",
+            "mutation_entities": "Mutation entities",
+            "questions": "Question",
+        },
+        inplace=True,
+    )
+    result.index = ["QAG pipeline results"] * len(result)
+    print("completed")
+    print("writing result string now")
 
     result = result.T.to_dict()
-    print(
+    print("result {}".format(result))
 
     result_string = ""
 
@@ -334,7 +310,6 @@ def execute_pipeline(question: str):
     # return json.dumps(result.T.to_dict(), indent=2)
 
 
-
 def visible_component(input_text):
     return gr.update(value="WHATEVER")
 
@@ -349,16 +324,14 @@ with gr.Blocks(title="GDC QAG MCP server") as GDC_QAG_QUERY:
 
     with gr.Row():
         query_input = gr.Textbox(
-            lines
+            lines=3,
             label="Search Query",
             placeholder='e.g. "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
             info="Required: Enter your search query. Click on Examples to execute example queries. Please retry query if API is unavailable or connection aborts.",
         )
 
     gr.Examples(
-        examples=EXAMPLE_INPUTS,
-        inputs=query_input,
-        example_labels = EXAMPLE_LABELS
+        examples=EXAMPLE_INPUTS, inputs=query_input, example_labels=EXAMPLE_LABELS
     )
 
     execute_button = gr.Button("Execute", variant="primary")
@@ -378,4 +351,4 @@ with gr.Blocks(title="GDC QAG MCP server") as GDC_QAG_QUERY:
 
 
 if __name__ == "__main__":
-    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)
+    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)