Spaces:
Sleeping
Sleeping
global init lm
Browse files
app.py
CHANGED
|
@@ -20,6 +20,9 @@ from transformers import (
|
|
| 20 |
|
| 21 |
from methods import gdc_api_calls, utilities
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
# set up various tokens
|
| 24 |
hf_TOKEN = os.environ.get("hf_svc_ctds", False)
|
| 25 |
|
|
@@ -327,26 +330,26 @@ def construct_and_execute_api_call(query):
|
|
| 327 |
def generate_percentage_response(modified_query):
|
| 328 |
# set_seed(1042)
|
| 329 |
regex = "The final response is: \d*\.\d*%"
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
return
|
| 334 |
|
| 335 |
|
| 336 |
# generate llama model descriptive response
|
| 337 |
@utilities.timeit
|
| 338 |
@spaces.GPU(duration=10)
|
| 339 |
-
def generate_descriptive_response(modified_query):
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
"desc_response",
|
| 344 |
n=1,
|
| 345 |
temperature=0,
|
| 346 |
max_tokens=100,
|
| 347 |
regex="^[^\\n]*[.\S+]$",
|
| 348 |
)
|
| 349 |
-
return
|
| 350 |
|
| 351 |
|
| 352 |
@utilities.timeit
|
|
|
|
| 20 |
|
| 21 |
from methods import gdc_api_calls, utilities
|
| 22 |
|
| 23 |
+
# global init to test guidance speed up
|
| 24 |
+
lm = Transformers(model=model, tokenizer=tok)
|
| 25 |
+
|
| 26 |
# set up various tokens
|
| 27 |
hf_TOKEN = os.environ.get("hf_svc_ctds", False)
|
| 28 |
|
|
|
|
| 330 |
def generate_percentage_response(modified_query):
|
| 331 |
# set_seed(1042)
|
| 332 |
regex = "The final response is: \d*\.\d*%"
|
| 333 |
+
session = lm.copy()
|
| 334 |
+
session += modified_query
|
| 335 |
+
session += guidance_gen("pct_response", n=1, temperature=0, max_tokens=40, regex=regex)
|
| 336 |
+
return session["pct_response"]
|
| 337 |
|
| 338 |
|
| 339 |
# generate llama model descriptive response
|
| 340 |
@utilities.timeit
|
| 341 |
@spaces.GPU(duration=10)
|
| 342 |
+
def generate_descriptive_response(modified_query):
|
| 343 |
+
session = lm.copy()
|
| 344 |
+
session += modified_query
|
| 345 |
+
session += guidance_gen(
|
| 346 |
"desc_response",
|
| 347 |
n=1,
|
| 348 |
temperature=0,
|
| 349 |
max_tokens=100,
|
| 350 |
regex="^[^\\n]*[.\S+]$",
|
| 351 |
)
|
| 352 |
+
return session["desc_response"]
|
| 353 |
|
| 354 |
|
| 355 |
@utilities.timeit
|