Stepan222 committed on
Commit
eb8769a
·
verified ·
1 Parent(s): de2ea0b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +24 -56
app.py CHANGED
@@ -5,7 +5,6 @@ from peft import PeftModel
5
  from PIL import Image
6
  import requests
7
  from io import BytesIO
8
- import gc
9
 
10
  model = None
11
  processor = None
@@ -13,23 +12,19 @@ processor = None
13
  def load_model():
14
  global model, processor
15
  if model is None:
16
- print("Загружаю базовую модель Qwen2.5-VL-7B-Instruct...")
17
- gc.collect()
18
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
19
 
20
  base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
21
  "Qwen/Qwen2.5-VL-7B-Instruct",
22
- torch_dtype=torch.bfloat16,
23
  device_map="auto",
24
- trust_remote_code=True,
25
- attn_implementation="flash_attention_2"
26
  )
27
 
28
- print("Применяю LoRA адаптер...")
29
  model = PeftModel.from_pretrained(
30
  base_model,
31
- "Stepan222/oem-fake-classifier-qwen2vl",
32
- torch_dtype=torch.bfloat16
33
  )
34
  model.eval()
35
 
@@ -37,79 +32,52 @@ def load_model():
37
  "Qwen/Qwen2.5-VL-7B-Instruct",
38
  trust_remote_code=True
39
  )
40
- print("Модель загружена!")
41
  return model, processor
42
 
43
  def classify(image_url: str, title: str, description: str = ""):
44
  try:
45
  model, processor = load_model()
46
  except Exception as e:
47
- return f"Ошибка загрузки модели: {e}"
48
 
49
  try:
50
- if image_url.startswith("http"):
51
- response = requests.get(image_url, timeout=10)
52
- image = Image.open(BytesIO(response.content)).convert("RGB")
53
- else:
54
- return "Введите URL изображения"
55
- except Exception as e:
56
- return f"Ошибка загрузки изображения: {e}"
57
 
58
  text = f"Title: {title}"
59
  if description:
60
  text += f"\nDescription: {description}"
61
 
62
- prompt = f"""Analyze this eBay listing for auto parts. Classify as OEM (original manufacturer) or FAKE (aftermarket/counterfeit).
63
 
64
  {text}
65
 
66
- Look at the image and text carefully. Respond with:
67
- - OEM: [confidence]% - [brief reason]
68
- - FAKE: [confidence]% - [brief reason]"""
69
 
70
- messages = [
71
- {
72
- "role": "user",
73
- "content": [
74
- {"type": "image", "image": image},
75
- {"type": "text", "text": prompt}
76
- ]
77
- }
78
- ]
79
 
80
  text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
81
- inputs = processor(
82
- text=[text_input],
83
- images=[image],
84
- return_tensors="pt",
85
- padding=True
86
- ).to(model.device)
87
 
88
  with torch.no_grad():
89
- outputs = model.generate(
90
- **inputs,
91
- max_new_tokens=100,
92
- do_sample=False
93
- )
94
-
95
- response = processor.batch_decode(outputs, skip_special_tokens=True)[0]
96
- if "assistant" in response.lower():
97
- response = response.split("assistant")[-1].strip()
98
 
99
- return response
 
100
 
101
  demo = gr.Interface(
102
  fn=classify,
103
  inputs=[
104
- gr.Textbox(label="URL изображения", placeholder="https://i.ebayimg.com/..."),
105
- gr.Textbox(label="Название товара", placeholder="BMW Genuine OEM Part..."),
106
- gr.Textbox(label="Описание (опционально)", placeholder="", lines=2)
107
  ],
108
- outputs=gr.Textbox(label="Результат", lines=3),
109
- title="OEM/Fake Classifier (Qwen2.5-VL + LoRA)",
110
- description="Классификатор автозапчастей. Space засыпает через 5 минут бездействия для экономии.",
111
  allow_flagging="never"
112
  )
113
 
114
- if __name__ == "__main__":
115
- demo.launch()
 
5
  from PIL import Image
6
  import requests
7
  from io import BytesIO
 
8
 
9
  model = None
10
  processor = None
 
12
def load_model():
    """Lazily load the base Qwen2.5-VL model, apply the LoRA adapter, and
    build the matching processor.

    Results are cached in the module-level globals ``model`` and
    ``processor`` so repeated calls after the first are cheap.

    Returns:
        tuple: ``(model, processor)`` ready for inference.
    """
    global model, processor
    if model is None:
        print("Загружаю модель...")

        # fp16 halves memory vs. fp32; device_map="auto" lets accelerate
        # place layers on whatever devices are available.
        base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )

        print("Применяю LoRA...")
        model = PeftModel.from_pretrained(
            base_model,
            "Stepan222/oem-fake-classifier-qwen2vl",
        )
        model.eval()  # inference only — disable dropout / training behavior

        # NOTE(review): the `processor = AutoProcessor.from_pretrained(` head
        # line was lost at a diff-hunk boundary in the extracted source;
        # reconstructed from the surviving argument lines and the
        # `return model, processor` below — confirm against the repository.
        processor = AutoProcessor.from_pretrained(
            "Qwen/Qwen2.5-VL-7B-Instruct",
            trust_remote_code=True,
        )
        print("Модель готова!")
    return model, processor
37
 
38
def classify(image_url: str, title: str, description: str = ""):
    """Classify an eBay auto-part listing as OEM or FAKE.

    Args:
        image_url: HTTP(S) URL of the listing photo.
        title: Listing title text.
        description: Optional listing description.

    Returns:
        str: the model's textual verdict, or a user-facing (Russian)
        error message when the model or the image cannot be loaded.
    """
    try:
        model, processor = load_model()
    except Exception as e:
        return f"Ошибка модели: {e}"

    try:
        response = requests.get(image_url, timeout=10)
        # Fail fast on HTTP errors instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception:  # was a bare `except:` — also swallowed KeyboardInterrupt
        return "Не удалось загрузить изображение"

    text = f"Title: {title}"
    if description:
        text += f"\nDescription: {description}"

    prompt = f"""Analyze this eBay auto part listing. Is it OEM (original) or FAKE (aftermarket)?

{text}

Reply: OEM or FAKE with confidence % and reason."""

    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]

    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text_input], images=[image], return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        out = model.generate(**inputs, max_new_tokens=80, do_sample=False)

    resp = processor.batch_decode(out, skip_special_tokens=True)[0]
    # Strip everything up to the final "assistant" role marker. The original
    # checked `resp.lower()` but split on the literal "assistant", so a
    # differently-cased marker passed the check yet was never stripped;
    # search case-insensitively for both steps.
    marker = resp.lower().rfind("assistant")
    if marker != -1:
        resp = resp[marker + len("assistant"):].strip()
    return resp
70
 
71
# Gradio UI: three text inputs (image URL, title, optional description)
# feeding classify(), one text output carrying the model's verdict.
demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.Textbox(label="Image URL"),
        gr.Textbox(label="Title"),
        gr.Textbox(label="Description")
    ],
    outputs=gr.Textbox(label="Result"),
    title="OEM/Fake Classifier",
    allow_flagging="never"
)

# Entry-point guard restored (present in the previous revision of this file):
# launch the server only when run as a script, so importing this module
# does not start it as a side effect.
if __name__ == "__main__":
    demo.launch()