Update handler.py
Browse files- handler.py +3 -2
handler.py
CHANGED
|
@@ -62,14 +62,15 @@ class EndpointHandler:
|
|
| 62 |
|
| 63 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
| 64 |
embedding = data.pop("embedding", None)
|
|
|
|
| 65 |
max_length=200
|
| 66 |
with torch.no_grad():
|
| 67 |
-
outputs = self.model(ada_embedding=
|
| 68 |
decoded_tkns = outputs.logits.argmax(dim=-1)
|
| 69 |
|
| 70 |
for _ in range(max_length):
|
| 71 |
with torch.no_grad():
|
| 72 |
-
outputs = self.model(ada_embedding=
|
| 73 |
|
| 74 |
# Get the most likely next token, sampled from top k
|
| 75 |
logits = outputs.logits[:, -1]
|
|
|
|
| 62 |
|
| 63 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
| 64 |
embedding = data.pop("embedding", None)
|
| 65 |
+
ada_embedding = torch.tensor(embedding).unsqueeze(0)
|
| 66 |
max_length=200
|
| 67 |
with torch.no_grad():
|
| 68 |
+
outputs = self.model(ada_embedding=ada_embedding, decoded_tkns=None)
|
| 69 |
decoded_tkns = outputs.logits.argmax(dim=-1)
|
| 70 |
|
| 71 |
for _ in range(max_length):
|
| 72 |
with torch.no_grad():
|
| 73 |
+
outputs = self.model(ada_embedding=ada_embedding, decoded_tkns=decoded_tkns)
|
| 74 |
|
| 75 |
# Get the most likely next token, sampled from top k
|
| 76 |
logits = outputs.logits[:, -1]
|