Spaces:
Build error
Build error
Commit
·
a249507
1
Parent(s):
4718ba4
Handle input errors
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from transformers import AutoModelForCausalLM
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
|
|
|
| 4 |
|
| 5 |
model = AutoModelForCausalLM.from_pretrained("Manuel2011/addition_model")
|
| 6 |
|
|
@@ -22,7 +23,14 @@ class NumberTokenizer:
|
|
| 22 |
tokenizer = NumberTokenizer(13)
|
| 23 |
|
| 24 |
def generate_solution(input, solution_length=6, model=model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
model.eval()
|
|
|
|
| 26 |
input = torch.tensor(tokenizer(input))
|
| 27 |
input = input
|
| 28 |
solution = []
|
|
|
|
| 1 |
from transformers import AutoModelForCausalLM
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
+
import re
|
| 5 |
|
| 6 |
model = AutoModelForCausalLM.from_pretrained("Manuel2011/addition_model")
|
| 7 |
|
|
|
|
| 23 |
tokenizer = NumberTokenizer(13)
|
| 24 |
|
| 25 |
def generate_solution(input, solution_length=6, model=model):
|
| 26 |
+
try:
|
| 27 |
+
parsed_input = re.search(r'(\d)\s*\+\s*(\d)', input)
|
| 28 |
+
first_number = int(parsed_input.group(1))
|
| 29 |
+
second_number = int(parsed_input.group(2))
|
| 30 |
+
except:
|
| 31 |
+
return 'Invalid input'
|
| 32 |
model.eval()
|
| 33 |
+
input = f'{first_number} + {second_number} ='
|
| 34 |
input = torch.tensor(tokenizer(input))
|
| 35 |
input = input
|
| 36 |
solution = []
|