Michael commited on
Commit
64da2e4
·
1 Parent(s): 21e4c5e

add app and initial commit

Browse files
Files changed (4) hide show
  1. app.py +64 -0
  2. gdc_pipeline.py +359 -0
  3. poetry.lock +0 -0
  4. pyproject.toml +48 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import pandas as pd
from gdc_pipeline import execute_pipeline, setup_args, setup_models_and_data

# Load models and reference data once at import time so every Gradio
# request reuses the same objects (LLM + tokenizer, intent model,
# gene/mutation data, GDC project mappings).
qag_requirements = setup_models_and_data()

# Example query:
# question = 'What is the co-occurrence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?'
9
+
10
def wrapped_execute_pipeline(question: str):
    """Run the GDC QAG pipeline for a single free-text question.

    Wraps ``execute_pipeline`` so it can be used directly as a Gradio
    event handler: the question is placed into a one-row DataFrame and
    the module-level models/data (``qag_requirements``) are supplied.

    Returns the pipeline result, or a user-facing retry message if the
    pipeline raises for any reason.
    """
    df = pd.DataFrame({'questions' : [question]})
    print(f'Question received: {question}')
    try:
        result = execute_pipeline(
            df,
            qag_requirements.gdc_genes_mutations,
            qag_requirements.model,
            qag_requirements.tok,
            qag_requirements.intent_model,
            qag_requirements.intent_tok,
            qag_requirements.project_mappings,
            output_file_prefix=None
        )
    except Exception as e:
        # Best-effort UX: show a retry hint to the user, but log the
        # underlying error (the original swallowed it silently).
        print(f'execute_pipeline failed for question {question!r}: {e}')
        result = 'Unable to execute GDC API, can you please retry with a template question?'
    return result
27
+
28
def visible_component(input_text):
    """Gradio helper: return an update setting the component value to "WHATEVER".

    ``input_text`` is accepted to satisfy the event-handler signature
    but is not used.
    """
    updated = gr.update(value="WHATEVER")
    return updated
30
+
31
+
32
# Create Gradio interface
with gr.Blocks(title="GDC QAG MCP server") as demo:
    gr.Markdown(
        """
        # GDC QAG Service
        """
    )

    with gr.Row():
        query_input = gr.Textbox(
            lines = 3,
            label="Search Query",
            placeholder='e.g. "What is the co-occurence frequency of somatic homozygous deletions in CDKN2A and CDKN2B in the mesothelioma project TCGA-MESO in the genomic data commons?"',
            info="Required: Enter your search query",
        )

    search_button = gr.Button("Search", variant="primary")

    output = gr.Textbox(
        label="Query Result",
        lines=10,
        max_lines=25,
        info="The Result of the Query will appear here",
    )

    # Wire the button to the pipeline wrapper.
    search_button.click(
        fn=wrapped_execute_pipeline,
        inputs=[query_input],
        outputs=output,
    )

if __name__ == "__main__":
    # BUG FIX: the Blocks object above is named `demo`; launching the
    # undefined name GDC_QAG_QUERY raised NameError at startup.
    demo.launch(mcp_server=True, show_api=True)
gdc_pipeline.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# QAG pipeline entry point script


import argparse
import os
from types import SimpleNamespace

import pandas as pd
import spaces
from guidance import gen as guidance_gen
from guidance.models import Transformers
from tqdm import tqdm
from transformers import set_seed

from methods import gdc_api_calls, utilities

# Enable DataFrame.progress_apply with tqdm progress bars.
tqdm.pandas()
19
+
20
+
21
def execute_api_call(
    intent,
    gene_entities,
    mutation_entities,
    cancer_entities,
    query,
    gdc_genes_mutations,
    project_mappings,
):
    """Dispatch a recognized user intent to the matching GDC API helper.

    Returns a ``(result, cancer_entities)`` pair.  Several helpers refine
    the cancer-entity list themselves, so the (possibly updated) list is
    returned alongside the API result; the remaining intents return the
    entity list unchanged.  Unrecognized intents yield an explanatory
    message.
    """
    if intent == "ssm_frequency":
        return utilities.get_ssm_frequency(
            gene_entities, mutation_entities, cancer_entities, project_mappings
        )
    if intent == "top_mutated_genes_by_project":
        return (
            gdc_api_calls.get_top_mutated_genes_by_project(cancer_entities, top_k=10),
            cancer_entities,
        )
    if intent == "most_frequently_mutated_gene":
        # Same helper as above, restricted to the single top gene.
        return (
            gdc_api_calls.get_top_mutated_genes_by_project(cancer_entities, top_k=1),
            cancer_entities,
        )
    if intent == "freq_cnv_loss_or_gain":
        return gdc_api_calls.get_freq_cnv_loss_or_gain(
            gene_entities, cancer_entities, query, cnv_and_ssm_flag=False
        )
    if intent == "msi_h_frequency":
        return gdc_api_calls.get_msi_frequency(cancer_entities)
    if intent == "cnv_and_ssm":
        return utilities.get_freq_of_cnv_and_ssms(
            query, cancer_entities, gene_entities, gdc_genes_mutations
        )
    if intent == "top_cases_counts_by_gene":
        return gdc_api_calls.get_top_cases_counts_by_gene(
            gene_entities, cancer_entities
        )
    if intent == "project_summary":
        return gdc_api_calls.get_project_summary(cancer_entities), cancer_entities
    return "user intent not recognized, or use case not covered", cancer_entities
61
+
62
+
63
# function to combine entities, intent and API call
def construct_and_execute_api_call(
    query, gdc_genes_mutations, project_mappings, intent_model, intent_tok
):
    """Infer entities and intent from *query*, then execute the GDC API call.

    Per-query pipeline:
      1. NER for cancer entities with the primary model, then a fallback
         model if nothing is found.
      2. Post-process entities against the GDC project mappings; if still
         empty, fall back to all known projects.
      3. Infer gene and mutation entities from the query text.
      4. Classify user intent and dispatch via ``execute_api_call``.

    Returns a SimpleNamespace with helper_output, cancer_entities,
    intent, gene_entities and mutation_entities.  On any API-call
    failure the helper output and cancer entities are emptied (the
    error is printed, not raised).
    """
    print("query:\n{}\n".format(query))
    # Infer entities
    initial_cancer_entities = utilities.return_initial_cancer_entities(
        query, model="en_ner_bc5cdr_md"
    )

    # Fallback NER model when the first pass finds nothing.
    if not initial_cancer_entities:
        try:
            initial_cancer_entities = utilities.return_initial_cancer_entities(
                query, model="en_core_sci_md"
            )
        except Exception as e:
            print("unable to guess cancer entities {}".format(str(e)))
            initial_cancer_entities = []

    cancer_entities = utilities.postprocess_cancer_entities(
        project_mappings, initial_cancer_entities=initial_cancer_entities, query=query
    )

    # if cancer entities is empty from above methods
    # return all projects
    if not cancer_entities:
        cancer_entities = list(project_mappings.keys())
    gene_entities = utilities.infer_gene_entities_from_query(query, gdc_genes_mutations)
    mutation_entities = utilities.infer_mutation_entities(
        gene_entities=gene_entities,
        query=query,
        gdc_genes_mutations=gdc_genes_mutations,
    )

    print("gene entities {}".format(gene_entities))
    print("mutation entities {}".format(mutation_entities))
    print("cancer entities {}".format(cancer_entities))

    # infer user intent
    intent = utilities.infer_user_intent(query, intent_model, intent_tok)
    print("user intent:\n{}\n".format(intent))
    try:
        api_call_result, cancer_entities = execute_api_call(
            intent,
            gene_entities,
            mutation_entities,
            cancer_entities,
            query,
            gdc_genes_mutations,
            project_mappings,
        )
        print("api_call_result {}".format(api_call_result))
        # print('cancer_entities {}'.format(cancer_entities))
    except Exception as e:
        # Best-effort: report and return empty outputs so callers can
        # filter out failed rows downstream.
        print("unable to process query {} {}".format(query, str(e)))
        api_call_result = []
        cancer_entities = []
    return SimpleNamespace(
        helper_output=api_call_result,
        cancer_entities=cancer_entities,
        intent=intent,
        gene_entities=gene_entities,
        mutation_entities=mutation_entities,
    )
127
+
128
+
129
# generate llama model response
@spaces.GPU(duration=60)
def generate_response(modified_query, model, tok):
    """Generate a regex-constrained LLM completion for *modified_query*.

    Seeds the generator for reproducibility and constrains the output so
    it ends with a line of the form "The final answer is: <float>%".

    Returns the generated text captured under the "gen_response" key.
    """
    set_seed(1042)
    # BUG FIX: raw string — "\d" in a plain literal is an invalid escape
    # sequence (SyntaxWarning on modern Python).  The pattern is unchanged.
    regex = r"The final answer is: \d*\.\d*%"
    lm = Transformers(model=model, tokenizer=tok)
    lm += modified_query
    lm += guidance_gen(
        "gen_response",
        n=1,
        temperature=0,
        max_tokens=1000,
        # to try remove repetition, this is not a param in guidance
        # repetition_penalty=1.2,
        regex=regex,
    )
    return lm["gen_response"]
146
+
147
+
148
def batch_test(
    query,
    model,
    tok,
    gdc_genes_mutations,
    project_mappings,
    intent_model,
    intent_tok
):
    """Run one query through the base LLM and the GDC helper pipeline.

    Returns a pandas Series of
    [llama_base_output, helper_output, cancer_entities, intent,
    gene_entities, mutation_entities], suitable for assignment into a
    DataFrame via ``apply``.
    """
    modified_query = utilities.construct_modified_query_base_llm(query)
    llama_base_output = generate_response(modified_query, model, tok)
    try:
        result = construct_and_execute_api_call(
            query, gdc_genes_mutations, project_mappings, intent_model, intent_tok
        )
    except Exception as e:
        # unable to compute at this time, recheck
        # BUG FIX: `result` did not exist when the call above raised, so
        # the original `result.helper_output = []` itself raised NameError.
        print("construct_and_execute_api_call failed for {}: {}".format(query, str(e)))
        result = SimpleNamespace(
            helper_output=[],
            cancer_entities=[],
            intent=None,
            gene_entities=[],
            mutation_entities=[],
        )
    # if there is not a helper output for each unique cancer entity
    # log error to inspect and reprocess query later
    # BUG FIX: the original wrapped a bare `==` comparison in try/except;
    # a comparison never raises, so the mismatch went undetected.
    if len(result.helper_output) != len(result.cancer_entities):
        msg = (
            "there is not a unique helper output for each unique "
            "cancer entity in {}".format(query)
        )
        print("exception {}".format(msg))
        result.helper_output = []
        result.cancer_entities = []

    return pd.Series(
        [
            llama_base_output,
            result.helper_output,
            result.cancer_entities,
            result.intent,
            result.gene_entities,
            result.mutation_entities,
        ]
    )
190
+
191
+
192
def setup_args(argv=None):
    """Parse command-line arguments for the pipeline.

    Exactly one of ``--input-file`` or ``--question`` is required.

    Args:
        argv: optional list of argument strings.  Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — backward
            compatible with the original no-argument call.  Passing an
            explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with ``input_file`` and ``question``.
    """
    parser = argparse.ArgumentParser()
    # add functionality to either pass in a file with questions or a single question
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--input-file",
        dest="input_file",
        help="path to input file with questions. input file should contain one column named questions, with each question on one line",
    )
    group.add_argument("--question", dest="question", help="a single question string")
    return parser.parse_args(argv)
203
+
204
+
205
def get_prefinal_response(row, model, tok):
    """Build the helper-augmented prompt for one row and generate a response.

    Args:
        row: DataFrame row with "questions" and "helper_output" fields.
        model, tok: LLM and tokenizer passed through to generate_response.

    Returns:
        pd.Series([modified_query, prefinal_llama_with_helper_output]).

    Raises:
        KeyError: if the row is missing either expected field.
    """
    try:
        query = row["questions"]
        helper_output = row["helper_output"]
    except KeyError as e:
        # BUG FIX: the old handler printed `query`/`helper_output` — which
        # are undefined when the lookup fails — and then fell through to
        # use them anyway, guaranteeing a NameError.  Report and re-raise.
        print(f"unable to retrieve questions/helper_output from row: {e}")
        raise
    modified_query = utilities.construct_modified_query(query, helper_output)
    prefinal_llama_with_helper_output = generate_response(modified_query, model, tok)
    return pd.Series([modified_query, prefinal_llama_with_helper_output])
214
+
215
+
216
def setup_models_and_data():
    """Load every model and reference dataset the pipeline needs.

    Returns a SimpleNamespace bundling the GDC project mappings, the
    gene/mutation reference data, the base LLM + tokenizer, and the
    intent-classification model + tokenizer.
    """
    # from env
    print("loading HF token")
    auth_token = os.environ.get("HF_TOKEN") or True

    print("getting gdc project information")
    # retrieve and load GDC project mappings
    project_mappings = gdc_api_calls.get_gdc_project_ids(start=0, stop=86)

    print("loading gdc genes and mutations")
    gdc_genes_mutations = utilities.load_gdc_genes_mutations_hf(auth_token)

    print("loading llama-3B model")
    model, tok = utilities.load_llama_llm(auth_token)

    print('loading intent model')
    intent_model, intent_tok = utilities.load_intent_model_hf(auth_token)

    # Bundle everything the pipeline entry points need.
    bundle = {
        "project_mappings": project_mappings,
        "gdc_genes_mutations": gdc_genes_mutations,
        "model": model,
        "tok": tok,
        "intent_model": intent_model,
        "intent_tok": intent_tok,
    }
    return SimpleNamespace(**bundle)
241
+
242
+
243
@utilities.timeit
def execute_pipeline(
    df, gdc_genes_mutations, model,
    tok, intent_model, intent_tok,
    project_mappings, output_file_prefix
):
    """Run the full QAG pipeline over a DataFrame of questions.

    Steps:
      1. For every question, run the base LLM and the GDC helper
         (``batch_test``) and record entities/intent/helper output.
      2. Keep rows with non-empty helper output and exactly one helper
         result per cancer entity; explode to one row per entity.
      3. Generate the helper-augmented ("prefinal") LLM response.
      4. Post-process into the final response columns.

    If *output_file_prefix* is truthy, results are written to
    csvs/<prefix>.results.csv and the full DataFrame is returned;
    otherwise the transposed final-column view is returned.
    """
    print("starting pipeline")

    # queries input file
    print(f"running test on input {df}")
    df[
        [
            "llama_base_output",
            "helper_output",
            "cancer_entities",
            "intent",
            "gene_entities",
            "mutation_entities",
        ]
    ] = df["questions"].progress_apply(
        lambda x: batch_test(
            x,
            model,
            tok,
            gdc_genes_mutations,
            project_mappings,
            intent_model,
            intent_tok
        )
    )

    # retain responses with helper output
    df["len_helper"] = df["helper_output"].apply(lambda x: len(x))
    # BUG FIX: .copy() — the original assigned new columns to a slice of
    # df, triggering SettingWithCopyWarning (and a silent no-op under
    # pandas copy-on-write).
    df_filtered = df[df["len_helper"] != 0].copy()
    df_filtered["len_ce"] = df_filtered["cancer_entities"].apply(lambda x: len(x))
    # retain rows where one response is retrieved for each cancer entity
    df_filtered["ce_eq_helper"] = df_filtered.apply(
        lambda x: x["len_ce"] == x["len_helper"], axis=1
    )
    df_filtered = df_filtered[df_filtered["ce_eq_helper"]]
    df_filtered_exploded = df_filtered.explode(
        ["helper_output", "cancer_entities"], ignore_index=True
    )
    df_filtered_exploded[["modified_prompt", "pre_final_llama_with_helper_output"]] = (
        df_filtered_exploded.progress_apply(
            lambda x: get_prefinal_response(x, model, tok), axis=1
        )
    )

    ### postprocess response
    print("postprocessing response")
    df_filtered_exploded[
        [
            "llama_base_stat",
            "delta_llama",
            "value_changed",
            "ground_truth_stat",
            "generated_stat_prefinal",
            "delta_prefinal",
            "generated_stat_final",
            "delta_final",
            "final_response",
        ]
    ] = df_filtered_exploded.progress_apply(
        lambda x: utilities.postprocess_response(x), axis=1
    )

    final_columns = utilities.get_final_columns()

    if output_file_prefix:
        final_output = os.path.join("csvs", output_file_prefix + ".results.csv")
        print("writing final results to {}".format(final_output))
        df_filtered_exploded.to_csv(final_output, columns=final_columns)
        result = df_filtered_exploded
    else:
        # Transposed view reads better when printing a single question.
        result = df_filtered_exploded[final_columns].T
    print('result {}'.format(result))
    print('completed')
    return result
322
+
323
+
324
def main():
    """CLI entry point: parse args, load models, and run the pipeline.

    Accepts either a CSV of questions (--input-file) or one question
    string (--question); results are written to csvs/ only in the
    file-input case.
    """
    args = setup_args()
    qag_requirements = setup_models_and_data()

    # Build the input frame and output prefix from whichever argument
    # was supplied (argparse guarantees exactly one).
    if args.input_file:
        df = pd.read_csv(args.input_file)
        output_file_prefix = os.path.basename(args.input_file).split(".")[0]
    elif args.question:
        df = pd.DataFrame({"questions": [args.question]})
        output_file_prefix = None
    else:
        return

    execute_pipeline(
        df,
        qag_requirements.gdc_genes_mutations,
        qag_requirements.model,
        qag_requirements.tok,
        qag_requirements.intent_model,
        qag_requirements.intent_tok,
        qag_requirements.project_mappings,
        output_file_prefix=output_file_prefix,
    )


if __name__ == "__main__":
    main()
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "gdc-qag"
3
+ version = "0.1.0"
4
+ description = "a package to run query augmented generation on the genomic data commons"
5
+ authors = [
6
+ {name = "aartiv",email = "aartiv@uchicago.edu"}
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">=3.10,<4.0"
10
+ dependencies = [
11
+ "spacy (==3.7.5)",
12
+ "en-core-sci-md @ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_core_sci_md-0.5.4.tar.gz#sha256=7c8fc52542dd1452ffce00b045c1298e2c185b7cf84793f8e0ec941987c09808",
13
+ "en-ner-bc5cdr-md @ https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.4/en_ner_bc5cdr_md-0.5.4.tar.gz#sha256=ffc73130a710edf851206199720cb2c744a043e032f5da6ba4bb36863deca778",
14
+ "huggingface-hub (>=0.33.2,<0.34.0)",
15
+ "langchain (>=0.3.26,<0.4.0)",
16
+ "langchain-core (>=0.3.68,<0.4.0)",
17
+ "langchain-text-splitters (>=0.3.8,<0.4.0)",
18
+ "langsmith (>=0.4.4,<0.5.0)",
19
+ "matplotlib-inline (>=0.1.7,<0.2.0)",
20
+ "numpy (==1.26.4)",
21
+ "pandas (==2.2.3)",
22
+ "requests (>=2.32.4,<3.0.0)",
23
+ "torch (==2.5.1)",
24
+ "tqdm (>=4.67.1,<5.0.0)",
25
+ "transformers (==4.49.0)",
26
+ "uvicorn (>=0.35.0,<0.36.0)",
27
+ "uvloop (==0.21.0)",
28
+ "vllm (==0.7.2)",
29
+ "gradio (>=5.35.0,<6.0.0)",
30
+ "tabulate (>=0.9.0,<0.10.0)",
31
+ "guidance (>=0.2.4,<0.3.0)",
32
+ "spaces (>=0.37.1,<0.38.0)",
33
+ "matplotlib (>=3.10.3,<4.0.0)",
34
+ "scipy (==1.13.1)",
35
+ "seaborn (>=0.13.2,<0.14.0)",
36
+ "statannotations (>=0.7.2,<0.8.0)",
37
+ "mcp (>=1.12.0,<2.0.0)"
38
+ ]
39
+ package-mode = false
40
+
41
+
42
+ [build-system]
43
+ requires = ["poetry-core>=2.0.0,<3.0.0"]
44
+ build-backend = "poetry.core.masonry.api"
45
+
46
+ [tool.poetry.group.dev.dependencies]
47
+ pre-commit = "^4.2.0"
48
+