aatu18 committed on
Commit
e4ea637
·
verified ·
1 Parent(s): ef893cf

global init lm

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -20,6 +20,9 @@ from transformers import (
20
 
21
  from methods import gdc_api_calls, utilities
22
 
 
 
 
23
  # set up various tokens
24
  hf_TOKEN = os.environ.get("hf_svc_ctds", False)
25
 
@@ -327,26 +330,26 @@ def construct_and_execute_api_call(query):
327
  def generate_percentage_response(modified_query):
328
  # set_seed(1042)
329
  regex = "The final response is: \d*\.\d*%"
330
- lm = Transformers(model=model, tokenizer=tok)
331
- lm += modified_query
332
- lm += guidance_gen("pct_response", n=1, temperature=0, max_tokens=40, regex=regex)
333
- return lm["pct_response"]
334
 
335
 
336
  # generate llama model descriptive response
337
  @utilities.timeit
338
  @spaces.GPU(duration=10)
339
- def generate_descriptive_response(modified_query):
340
- lm = Transformers(model=model, tokenizer=tok)
341
- lm += modified_query
342
- lm += guidance_gen(
343
  "desc_response",
344
  n=1,
345
  temperature=0,
346
  max_tokens=100,
347
  regex="^[^\\n]*[.\S+]$",
348
  )
349
- return lm["desc_response"]
350
 
351
 
352
  @utilities.timeit
 
20
 
21
  from methods import gdc_api_calls, utilities
22
 
23
+ # global init to test guidance speed up
24
+ lm = Transformers(model=model, tokenizer=tok)
25
+
26
  # set up various tokens
27
  hf_TOKEN = os.environ.get("hf_svc_ctds", False)
28
 
 
330
  def generate_percentage_response(modified_query):
331
  # set_seed(1042)
332
  regex = "The final response is: \d*\.\d*%"
333
+ session = lm.copy()
334
+ session += modified_query
335
+ session += guidance_gen("pct_response", n=1, temperature=0, max_tokens=40, regex=regex)
336
+ return session["pct_response"]
337
 
338
 
339
  # generate llama model descriptive response
340
  @utilities.timeit
341
  @spaces.GPU(duration=10)
342
+ def generate_descriptive_response(modified_query):
343
+ session = lm.copy()
344
+ session += modified_query
345
+ session += guidance_gen(
346
  "desc_response",
347
  n=1,
348
  temperature=0,
349
  max_tokens=100,
350
  regex="^[^\\n]*[.\S+]$",
351
  )
352
+ return session["desc_response"]
353
 
354
 
355
  @utilities.timeit