mlukowski committed
Commit 7db7222 · verified · Parent: 8e30203

token-fix (#3)


- update tokens for app (c4634a054c973948c282bc181693556957d52580)
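
In effect, the commit collapses the app's three per-model secrets (let_this_please_work, fineTest, query_intent_test) into one Hugging Face token read from the hf_svc_ctds secret and passed to every from_pretrained call. A minimal sketch of the resulting pattern, using only names that appear in the diff below (the False fallback means downloads are attempted anonymously when the secret is unset):

    import os

    from transformers import AutoTokenizer

    # One shared token for all gated/private Hugging Face downloads.
    hf_TOKEN = os.environ.get("hf_svc_ctds", False)  # False -> anonymous access

    # The same token is reused for every model and tokenizer load.
    tok = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct", trust_remote_code=True, token=hf_TOKEN
    )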

Files changed (1): app.py (+53, -80)
app.py CHANGED

@@ -1,23 +1,25 @@
+import json
 import os
 from types import SimpleNamespace
+
 import gradio as gr
-import json
 import pandas as pd
 import spaces
 import torch
-from methods import gdc_api_calls, utilities
-from transformers import AutoTokenizer, BertTokenizer, AutoModelForCausalLM, BertForSequenceClassification
 from guidance import gen as guidance_gen
 from guidance.models import Transformers
-from transformers import set_seed
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BertForSequenceClassification,
+    BertTokenizer,
+    set_seed,
+)

 from methods import gdc_api_calls, utilities

-
 # set up various tokens
-working_llama_token = os.environ.get("let_this_please_work", False)
-hf_TOKEN = os.environ.get("fineTest", False)
-intent_token = os.environ.get("query_intent_test", False)
+hf_TOKEN = os.environ.get("hf_svc_ctds", False)

 EXAMPLE_INPUTS = [
     "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?",
@@ -26,7 +28,7 @@ EXAMPLE_INPUTS = [
     "What fraction of cases have simple somatic mutations or copy number variants in ALK in Uterine Carcinosarcoma TCGA-UCS project in the genomic data commons?",
     "How often is microsatellite instability observed in Stomach Adenocarcinoma TCGA-STAD project in the genomic data commons?",
     "How often is the BRAF V600E mutation found in Skin Cutaneous Melanoma TCGA-SKCM project in the genomic data commons?",
-    "What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?"
+    "What is the co-occurence frequency of IDH1 R132H and TP53 R273C simple somatic mutations in the low grade glioma project TCGA-LGG in the genomic data commons?",
 ]

 EXAMPLE_LABELS = [
@@ -36,22 +38,20 @@ EXAMPLE_LABELS = [
     "copy number variants or somatic mutations",
     "microsatellite-instability",
     "simple somatic mutation",
-    "combination somatic mutations"
+    "combination somatic mutations",
 ]

 # set up requirements: models and data
 print("getting gdc project information")
 project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)

-print('loading intent model and tokenizer')
-model_id = 'uc-ctds/query_intent'
+print("loading intent model and tokenizer")
+model_id = "uc-ctds/query_intent"
 intent_tok = AutoTokenizer.from_pretrained(
-    model_id, trust_remote_code=True,
-    token=intent_token
+    model_id, trust_remote_code=True, token=hf_TOKEN
 )
-intent_model = BertForSequenceClassification.from_pretrained(
-    model_id, token=intent_token)
-intent_model = intent_model.to('cuda').eval()
+intent_model = BertForSequenceClassification.from_pretrained(model_id, token=hf_TOKEN)
+intent_model = intent_model.to("cuda").eval()


 print("loading gdc genes and mutations")
@@ -59,27 +59,15 @@ gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(hf_TOKEN)

 print("loading llama-3B model and tokenizer")
 model_id = "meta-llama/Llama-3.2-3B-Instruct"
-tok = AutoTokenizer.from_pretrained(
-    model_id, trust_remote_code=True,
-    token=working_llama_token
-)
+tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_TOKEN)
 model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.float16,
-    trust_remote_code=True,
-    token=working_llama_token
+    model_id, torch_dtype=torch.float16, trust_remote_code=True, token=hf_TOKEN
 )
-model = model.to('cuda').eval()
+model = model.to("cuda").eval()


 # execute_api_call
-def execute_api_call(
-    intent,
-    gene_entities,
-    mutation_entities,
-    cancer_entities,
-    query
-):
+def execute_api_call(intent, gene_entities, mutation_entities, cancer_entities, query):
     if intent == "ssm_frequency":
         result, cancer_entities = utilities.get_ssm_frequency(
             gene_entities, mutation_entities, cancer_entities, project_mappings
@@ -123,7 +111,7 @@ def infer_user_intent(query):
         "cnv_and_ssm": 4.0,
     }
     inputs = intent_tok(query, return_tensors="pt", truncation=True, padding=True)
-    inputs = {k: v.to('cuda') for k, v in inputs.items()}
+    inputs = {k: v.to("cuda") for k, v in inputs.items()}
     outputs = intent_model(**inputs)
     probs = torch.nn.functional.softmax(outputs.logits, dim=1)
     predicted_label = torch.argmax(probs, dim=1).item()
@@ -172,11 +160,7 @@ def construct_and_execute_api_call(query):
     print("user intent:\n{}\n".format(intent))
     try:
         api_call_result, cancer_entities = execute_api_call(
-            intent,
-            gene_entities,
-            mutation_entities,
-            cancer_entities,
-            query
+            intent, gene_entities, mutation_entities, cancer_entities, query
         )
         print("api_call_result {}".format(api_call_result))
     except Exception as e:
@@ -195,17 +179,11 @@ def construct_and_execute_api_call(query):
 # generate llama model response
 @spaces.GPU(duration=30)
 def generate_response(modified_query):
-    #set_seed(1042)
+    # set_seed(1042)
     regex = "The final answer is: \d*\.\d*%"
     lm = Transformers(model=model, tokenizer=tok)
     lm += modified_query
-    lm += guidance_gen(
-        "gen_response",
-        n=1,
-        temperature=0,
-        max_tokens=1000,
-        regex=regex
-    )
+    lm += guidance_gen("gen_response", n=1, temperature=0, max_tokens=1000, regex=regex)
     return lm["gen_response"]


@@ -256,10 +234,9 @@ def get_prefinal_response(row):
     return pd.Series([modified_query, prefinal_llama_with_helper_output])


-
 def execute_pipeline(question: str):
-    df = pd.DataFrame({'questions' : [question]})
-    print(f'Question received: {question}')
+    df = pd.DataFrame({"questions": [question]})
+    print(f"Question received: {question}")
     print("starting pipeline")
     print("CUDA available:", torch.cuda.is_available())
     print("CUDA device name:", torch.cuda.get_device_name(0))
@@ -278,9 +255,7 @@ def execute_pipeline(question: str):
     ] = df["questions"].apply(lambda x: batch_test(x))
     df_exploded = df.explode("helper_output", ignore_index=True)
     df_exploded[["modified_prompt", "pre_final_llama_with_helper_output"]] = (
-        df_exploded.apply(
-            lambda x: get_prefinal_response(x), axis=1
-        )
+        df_exploded.apply(lambda x: get_prefinal_response(x), axis=1)
     )
     ### postprocess response
     print("postprocessing response")
@@ -296,31 +271,32 @@ def execute_pipeline(question: str):
             "delta_final",
             "final_response",
         ]
-    ] = df_exploded.apply(
-        lambda x: utilities.postprocess_response(x), axis=1
-    )
+    ] = df_exploded.apply(lambda x: utilities.postprocess_response(x), axis=1)
     final_columns = utilities.get_final_columns()
     result = df_exploded[final_columns]
-    result.rename(columns={
-        'llama_base_output': 'llama-3B baseline output',
-        'modified_prompt': 'Query augmented prompt',
-        'helper_output': 'Processed GDC API result',
-        'ground_truth_stat': 'Ground truth frequency from GDC',
-        'llama_base_stat': 'llama-3B baseline frequency',
-        'delta_llama': 'llama-3B frequency - Ground truth frequency',
-        'final_response': 'Query augmented generation',
-        'intent': 'Intent',
-        'cancer_entities': 'Cancer entities',
-        'gene_entities': 'Gene entities',
-        'mutation_entities': 'Mutation entities',
-        'questions' : 'Question'
-    }, inplace=True)
-    result.index = ['QAG pipeline results'] * len(result)
-    print('completed')
-    print('writing result string now')
+    result.rename(
+        columns={
+            "llama_base_output": "llama-3B baseline output",
+            "modified_prompt": "Query augmented prompt",
+            "helper_output": "Processed GDC API result",
+            "ground_truth_stat": "Ground truth frequency from GDC",
+            "llama_base_stat": "llama-3B baseline frequency",
+            "delta_llama": "llama-3B frequency - Ground truth frequency",
+            "final_response": "Query augmented generation",
+            "intent": "Intent",
+            "cancer_entities": "Cancer entities",
+            "gene_entities": "Gene entities",
+            "mutation_entities": "Mutation entities",
+            "questions": "Question",
+        },
+        inplace=True,
+    )
+    result.index = ["QAG pipeline results"] * len(result)
+    print("completed")
+    print("writing result string now")

     result = result.T.to_dict()
-    print('result {}'.format(result))
+    print("result {}".format(result))

     result_string = ""

@@ -334,7 +310,6 @@ def execute_pipeline(question: str):
     # return json.dumps(result.T.to_dict(), indent=2)


-
 def visible_component(input_text):
     return gr.update(value="WHATEVER")

@@ -349,16 +324,14 @@ with gr.Blocks(title="GDC QAG MCP server") as GDC_QAG_QUERY:

     with gr.Row():
         query_input = gr.Textbox(
-            lines = 3,
+            lines=3,
             label="Search Query",
             placeholder='e.g. "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
             info="Required: Enter your search query. Click on Examples to execute example queries. Please retry query if API is unavailable or connection aborts.",
         )

     gr.Examples(
-        examples=EXAMPLE_INPUTS,
-        inputs=query_input,
-        example_labels = EXAMPLE_LABELS
+        examples=EXAMPLE_INPUTS, inputs=query_input, example_labels=EXAMPLE_LABELS
     )

     execute_button = gr.Button("Execute", variant="primary")
@@ -378,4 +351,4 @@


 if __name__ == "__main__":
-    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)
+    GDC_QAG_QUERY.launch(mcp_server=True, show_api=True)
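
For reference, a minimal sketch of the regex-constrained generation that generate_response performs after this change, mirroring the guidance calls in the diff; model, tok, and modified_query stand for the objects the surrounding app code already provides:

    from guidance import gen as guidance_gen
    from guidance.models import Transformers

    # Wrap the already-loaded HF model and tokenizer for constrained decoding.
    lm = Transformers(model=model, tokenizer=tok)
    lm += modified_query  # the query-augmented prompt built by the pipeline
    # Decoding is constrained to match the regex, so the reply always ends
    # with a percentage of the form "The final answer is: NN.NN%".
    lm += guidance_gen(
        "gen_response", n=1, temperature=0, max_tokens=1000,
        regex=r"The final answer is: \d*\.\d*%",
    )
    answer = lm["gen_response"]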