echeyde committed
Commit 4f2e221 · verified · 1 Parent(s): 5797dc2

Create handler.py

Files changed (1):
  1. handler.py +81 -0
handler.py ADDED
@@ -0,0 +1,81 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Initialize model and tokenizer. Inference Endpoints passes the
        # local model directory as `path`; it is accepted here to match the
        # custom-handler convention, though the weights below are loaded
        # from the Hub repo directly.
        self.tokenizer = AutoTokenizer.from_pretrained("VisitationAI/opt125-llama-visitation")
        self.model = AutoModelForCausalLM.from_pretrained("VisitationAI/opt125-llama-visitation")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, data):
        """
        Args:
            data: JSON input with structure:
                {
                    "inputs": "your text prompt here",
                    "parameters": {
                        "max_new_tokens": 50,
                        "temperature": 0.7,
                        "top_p": 0.9,
                        "do_sample": true
                    }
                }
        """
        # Get input text and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Default generation parameters, overridable per request
        generation_config = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.eos_token_id,
            "num_return_sequences": parameters.get("num_return_sequences", 1),
        }

        # Tokenize the prompt and move the tensors to the model's device
        inputs = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)

        # Generate text without tracking gradients
        with torch.no_grad():
            generated_ids = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **generation_config,
            )

        # Decode and return generated text
        generated_texts = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
        )

        return {
            "generated_text": generated_texts[0],  # First generation
            "all_generations": generated_texts,  # All generations if num_return_sequences > 1
        }

    def preprocess(self, data):
        """
        Handle different input formats: wrap a bare string prompt in the
        {"inputs": ...} shape expected by __call__.
        """
        if isinstance(data, str):
            return {"inputs": data}
        return data

    def postprocess(self, data):
        """
        Clean up output if needed (currently a pass-through).
        """
        return data
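
For a quick sanity check before deploying, the handler can be exercised locally. The sketch below is illustrative only: the prompt string is made up, it assumes the transformers and torch packages are installed, and it downloads the VisitationAI/opt125-llama-visitation weights from the Hub on first run.

    from handler import EndpointHandler

    # Instantiate the handler the way the endpoint runtime would
    handler = EndpointHandler()

    # Send a payload shaped like the structure documented in __call__
    result = handler({
        "inputs": "Hello, how can I help you",
        "parameters": {"max_new_tokens": 50, "temperature": 0.7},
    })
    print(result["generated_text"])

Once the endpoint is live, the same JSON body can be POSTed to its URL with an Authorization bearer token; the response carries the generated_text and all_generations fields returned above.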