Spaces:
Sleeping
Sleeping
File size: 3,547 Bytes
88cc76c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
from typing import List
import torch.nn as nn
import torch
from torchvision import transforms
import clip
from PIL import Image
import os
from app.api.dto.kg_query import PredictedLabel
# Label set of the fine-tuned classifier, grouped by the crop each label
# belongs to. Group order and in-group order are significant: flattened,
# they give the position of each label in the classifier's output vector.
_CLASSES_BY_CROP = {
    'caChua': [
        'benhVerticilliumWiltCaChua',
        'benhChayLaCaChua',
        'benhXoanLaCaChua',
        'benhDomLaCaChua',
    ],
    'san': [
        'benhNhenXanhSan',
        'benhKhamLaSan',
        'cassava healthy',
        'benhDomNau',
    ],
    'ngo': [
        'boCanhCungHaiLaNgo',
        'corn healthy',
        'benhChayLaNgo',
        'benhRiSatNgo',
        'benhSocLaNgo',
        'benhDomLaNgo',
    ],
    'luaNuoc': [
        'benhBacLaLua',
        'benhDaoOnLua',
        'benhDomNauLuaNuoc',
    ],
}

# CLASS_NAMES[i] is the label for output i; CROP_NAMES[i] is the crop that
# label belongs to. Deriving both from one table keeps them index-aligned.
CLASS_NAMES = [label for labels in _CLASSES_BY_CROP.values() for label in labels]
CROP_NAMES = [crop for crop, labels in _CLASSES_BY_CROP.items() for _ in labels]

# Trained state dict for the fine-tuned model, stored next to this module.
WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), 'weights', 'clip_finetuned.pth')
class CLIPFineTuner(nn.Module):
    """Linear classification head on top of a CLIP image encoder.

    Args:
        model: A CLIP model exposing ``encode_image`` and ``visual.output_dim``.
        num_classes: Number of logits produced by the linear head.
    """

    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model
        feature_dim = model.visual.output_dim
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return raw (un-normalised) class logits for the image batch ``x``."""
        # The backbone runs without autograd, so no gradients reach the CLIP
        # weights through this forward pass; only the linear head is updated.
        with torch.no_grad():
            embeddings = self.model.encode_image(x)
        # Cast to float32 (presumably to match the default-dtype linear head).
        return self.classifier(embeddings.float())
class CLIPModule:
    """Inference wrapper around the fine-tuned CLIP ViT-B/32 classifier.

    Loads the CLIP backbone plus the trained head from ``WEIGHTS_PATH`` and
    exposes :meth:`predict_image`, which scores an image against every label
    in ``CLASS_NAMES``.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the backbone on the same device the fine-tuned model will run on,
        # so the device is chosen in exactly one place.
        model, preprocess = clip.load("ViT-B/32", device=self.device, jit=False)
        # Head size is derived from the label list so the two cannot drift apart.
        self.model = CLIPFineTuner(model, len(CLASS_NAMES))
        # NOTE(review): the file holds a state dict; on torch>=1.13 consider
        # passing weights_only=True to avoid unpickling arbitrary objects.
        state_dict = torch.load(WEIGHTS_PATH, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.classes = CLASS_NAMES
        # Preprocessing pipeline returned by clip.load for this backbone.
        self.transform = preprocess

    def predict_image(self, image: Image.Image):
        """Classify an image and return every label ranked by confidence.

        Args:
            image: A PIL image. A file path is also accepted, because it is
                forwarded to the same loader used internally.

        Returns:
            list[PredictedLabel]: One entry per class (crop name, label and
            softmax probability), sorted by descending confidence.

        Raises:
            ValueError: If the image cannot be loaded or preprocessed.
        """
        logits = self.__predict(image)
        # Take the single batch row and move it to the host once, instead of
        # converting each element separately.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()
        predictions: List[PredictedLabel] = [
            PredictedLabel(
                crop_name=CROP_NAMES[idx],
                label=self.classes[idx],
                confidence=float(prob),
            )
            for idx, prob in enumerate(probabilities)
        ]
        # Sort by probability, highest first.
        predictions.sort(key=lambda p: p.confidence, reverse=True)
        return predictions

    def __predict(self, image_input):
        """Run the classifier on one image and return its raw logits.

        Args:
            image_input: Path to an image file (``str`` or ``os.PathLike``) or a
                ``PIL.Image.Image``.

        Returns:
            torch.Tensor: Un-normalised logits (no softmax applied) of shape
            ``(batch, num_classes)`` on ``self.device``. A 3-D preprocessed
            tensor is given a leading batch dimension of 1.

        Raises:
            ValueError: If the input type is unsupported or the image cannot be
                opened or preprocessed.
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            # Already a descriptive error (e.g. unsupported input type).
            raise
        except Exception as e:
            # Normalise any loader/preprocessing failure to ValueError, keeping
            # the original exception as the cause for debugging.
            raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}") from e
        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            return self.model(image_tensor)

    def __handle_image(self, image_input):
        """Return a PIL image from a file path, or pass a PIL image through.

        Raises:
            ValueError: If ``image_input`` is neither a path nor a PIL image.
        """
        if isinstance(image_input, (str, os.PathLike)):
            # The context manager closes the file handle. convert() returns a new
            # image that remains usable after the file is closed.
            with Image.open(image_input) as opened:
                return opened.convert('RGB')
        if isinstance(image_input, Image.Image):
            return image_input
        raise ValueError("Invalid image input")
|