from typing import List
import os

import torch
import torch.nn as nn
from torchvision import transforms
import clip
from PIL import Image

from app.api.dto.kg_query import PredictedLabel

CLASS_NAMES = ['benhVerticilliumWiltCaChua', 'benhChayLaCaChua', 'benhXoanLaCaChua', 'benhDomLaCaChua',
               'benhNhenXanhSan', 'benhKhamLaSan', 'cassava healthy', 'benhDomNau',
               'boCanhCungHaiLaNgo', 'corn healthy', 'benhChayLaNgo', 'benhRiSatNgo', 'benhSocLaNgo',
               'benhDomLaNgo', 'benhBacLaLua', 'benhDaoOnLua', 'benhDomNauLuaNuoc']
CROP_NAMES = ['caChua', 'caChua', 'caChua', 'caChua', 'san', 'san', 'san', 'san',
              'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'ngo', 'luaNuoc', 'luaNuoc', 'luaNuoc']
WEIGHTS_PATH = os.path.join(os.path.dirname(__file__), 'weights', 'clip_finetuned.pth')

class CLIPFineTuner(nn.Module):
    """CLIP image encoder with a linear classification head on top."""

    def __init__(self, model, num_classes):
        super(CLIPFineTuner, self).__init__()
        self.model = model
        self.classifier = nn.Linear(model.visual.output_dim, num_classes)

    def forward(self, x):
        # Keep the CLIP backbone frozen; only the linear head produces trainable logits.
        with torch.no_grad():
            features = self.model.encode_image(x).float()  # convert to float32
        return self.classifier(features)

class CLIPModule:
    def __init__(self):
        # Load the pretrained CLIP backbone and its matching preprocessing pipeline.
        model, preprocess = clip.load("ViT-B/32", jit=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPFineTuner(model, len(CLASS_NAMES))
        self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.classes = CLASS_NAMES
        self.transform = preprocess

    def predict_image(self, image: Image.Image):
        output = self.__predict(image)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        predictions: List[PredictedLabel] = []
        for idx, prob in enumerate(probabilities):
            predictions.append(PredictedLabel(
                crop_name=CROP_NAMES[idx],
                label=self.classes[idx],
                confidence=float(prob)
            ))
        # Sort predictions by confidence in descending order
        predictions.sort(key=lambda x: x.confidence, reverse=True)
        return predictions

    def __predict(self, image_input):
        """
        Run the model on a single image.

        Args:
            image_input: Path to an image file (str) or a PIL.Image object.

        Returns:
            torch.Tensor: Raw logits for the 17 classes; softmax is not applied here.
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Unable to process the input image: {str(e)}")
        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)  # add a batch dimension
        print(image_tensor.shape)
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            output = self.model(image_tensor)
        return output  # raw logits of shape (1, 17), no softmax

    def __handle_image(self, image_input):
        if isinstance(image_input, str):
            image = Image.open(image_input).convert('RGB')
        elif isinstance(image_input, Image.Image):
            image = image_input
        else:
            raise ValueError("Invalid image input")
        return image
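

# --- Minimal usage sketch ---
# Illustrative only: it assumes the fine-tuned weights exist at WEIGHTS_PATH and
# uses a hypothetical image path ("example.jpg"); adjust both to your environment.
if __name__ == "__main__":
    clip_module = CLIPModule()
    sample = Image.open("example.jpg").convert("RGB")
    results = clip_module.predict_image(sample)
    # Show the three most likely (crop, disease label, confidence) triples
    for pred in results[:3]:
        print(f"{pred.crop_name} - {pred.label}: {pred.confidence:.3f}")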