ckandemir committed on
Commit 7bfb35d · 1 Parent(s): cf258e1

Update handler.py

Files changed (1)
  1. handler.py +32 -30
handler.py CHANGED
@@ -1,36 +1,38 @@
-import base64
-import requests
-from PIL import Image
-import warnings
-from typing import Dict, List, Any, Union
-import torch
-from io import BytesIO
-from transformers import BlipProcessor, BlipForConditionalGeneration, BitsAndBytesConfig
+    def __init__(self, path=""):
+        # load the optimized model
+        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+        self.model = BlipForConditionalGeneration.from_pretrained(
+            "Salesforce/blip-image-captioning-large"
+        ).to(device)
+        self.model.eval()
+        self.model = self.model.to(device)
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Args:
+            data (dict):
+                Should contain:
+                    - 'images': List[bytes] of images.
+                    - 'texts': List[str] of associated texts. (Optional for unconditional captioning)
+        Return:
+            A dict with key "captions" and associated list of generated captions.
+        """
+        images = data.get("images")
+        texts = data.get("texts", ["a photography of"] * len(images))  # Default to "a photography of" if not provided
 
-class EndpointHandler():
-    def __init__(self, model_dir="Salesforce/blip-image-captioning-large"):
-        self.model = BlipForConditionalGeneration.from_pretrained(model_dir).to(device).eval()
-        self.processor = BlipProcessor.from_pretrained(model_dir)
+        raw_images = [Image.open(BytesIO(_img)).convert("RGB") for _img in images]
+
+        # Here, process both image and text
+        processed_inputs = [self.processor(img, txt, return_tensors="pt") for img, txt in zip(raw_images, texts)]
+        processed_inputs = {
+            "pixel_values": torch.cat([inp["pixel_values"] for inp in processed_inputs], dim=0).to(device),
+            "input_ids": torch.cat([inp["input_ids"] for inp in processed_inputs], dim=0).to(device),
+            "attention_mask": torch.cat([inp["attention_mask"] for inp in processed_inputs], dim=0).to(device)
+        }
 
-    def __call__(self, data):
-        input_data = data['inputs'][0]
-        img_url = input_data.get('img_url')
-        text_prompt = input_data.get('text', None)
-
-        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
-
-        if text_prompt:
-            inputs = self.processor(raw_image, text_prompt, return_tensors="pt").to(device)
-        else:
-            inputs = self.processor(raw_image, return_tensors="pt").to(device)
-
         with torch.no_grad():
-            generated_ids = self.model.generate(
-                **inputs,
-                max_new_tokens=150
-            )
-        captions = self.processor.decode(generated_ids[0], skip_special_tokens=True)
+            out = self.model.generate(**processed_inputs)
+
+        captions = self.processor.batch_decode(out, skip_special_tokens=True)
 
         return {"captions": captions}
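For reference, a minimal sketch of how a client could exercise the updated handler locally, assuming the EndpointHandler class wrapper, imports, and device definition from the previous revision still surround the methods shown in this diff; the handler module name and the image path below are placeholders:

# Local smoke test for the batched handler (sketch under the assumptions above).
from handler import EndpointHandler  # hypothetical import; adjust to the actual module path

# The new request schema expects raw image bytes rather than an image URL.
with open("example.jpg", "rb") as f:  # placeholder image file
    img_bytes = f.read()

handler = EndpointHandler()
payload = {
    "images": [img_bytes, img_bytes],                    # batch of two images as bytes
    "texts": ["a photography of", "a photography of"],   # optional; defaults to this prompt
}
result = handler(payload)
print(result["captions"])  # one generated caption per input image

Because the per-image processor outputs are concatenated with torch.cat, the sketch keeps every prompt identical so that the tokenized input_ids share the same length across the batch.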