from typing import List

import os

import torch
import torch.nn as nn
import clip
from PIL import Image

from app.api.dto.kg_query import PredictedLabel

CLASS_NAMES = ['benhVerticilliumWiltCaChua', 'benhChayLaCaChua', 'benhXoanLaCaChua', 'benhDomLaCaChua',
               'benhNhenXanhSan', 'benhKhamLaSan', 'cassava healthy', 'benhDomNau',
               'boCanhCungHaiLaNgo', 'corn healthy', 'benhChayLaNgo', 'benhRiSatNgo',
               'benhSocLaNgo', 'benhDomLaNgo',
               'benhBacLaLua', 'benhDaoOnLua', 'benhDomNauLuaNuoc']

# Crop name for each entry in CLASS_NAMES (same index order).
CROP_NAMES = ['caChua', 'caChua', 'caChua', 'caChua',
              'san', 'san', 'san', 'san',
              'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'ngo',
              'luaNuoc', 'luaNuoc', 'luaNuoc']

WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), 'weights', 'clip_finetuned.pth')


class CLIPFineTuner(nn.Module):
    """Frozen CLIP image encoder with a trainable linear classification head."""

    def __init__(self, model, num_classes):
        super(CLIPFineTuner, self).__init__()
        self.model = model
        self.classifier = nn.Linear(model.visual.output_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.model.encode_image(x).float()  # convert to float32
        return self.classifier(features)


class CLIPModule:
    def __init__(self):
        model, preprocess = clip.load("ViT-B/32", jit=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPFineTuner(model, len(CLASS_NAMES))
        self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.classes = CLASS_NAMES
        self.transform = preprocess

    def predict_image(self, image: Image.Image) -> List[PredictedLabel]:
        output = self.__predict(image)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]

        predictions: List[PredictedLabel] = []
        for idx, prob in enumerate(probabilities):
            predictions.append(PredictedLabel(
                crop_name=CROP_NAMES[idx],
                label=self.classes[idx],
                confidence=float(prob)
            ))

        # Sort in descending order of confidence
        predictions.sort(key=lambda x: x.confidence, reverse=True)
        return predictions

    def __predict(self, image_input):
        """
        Run the classifier on a single image.

        Args:
            image_input: Path to an image file (str) or a PIL.Image object.

        Returns:
            torch.Tensor: Raw logits of shape (1, 17); softmax is not applied here.
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Unable to process the input image: {str(e)}")

        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)  # add batch dimension

        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            output = self.model(image_tensor)
        return output  # raw logits for the 17 classes, no softmax

    def __handle_image(self, image_input):
        if isinstance(image_input, str):
            image = Image.open(image_input).convert('RGB')
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Invalid image input")
        return image
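

# Illustrative usage sketch, not exercised by the application: it assumes the
# fine-tuned weights exist at WEIGHTS_PATH, that PredictedLabel exposes
# crop_name/label/confidence attributes as constructed above, and that
# "sample_leaf.jpg" (a hypothetical placeholder) points to a real image.
if __name__ == "__main__":
    module = CLIPModule()  # loads ViT-B/32 and the fine-tuned classification head once
    results = module.predict_image(Image.open("sample_leaf.jpg").convert("RGB"))
    for p in results[:3]:  # top-3 predictions, already sorted by confidence
        print(f"{p.crop_name}: {p.label} ({p.confidence:.3f})")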