crop-diag-module / app /models /crop_clip.py
Sontranwakumo
init: move from github
88cc76c
raw
history blame
3.55 kB
from typing import List
import torch.nn as nn
import torch
from torchvision import transforms
import clip
from PIL import Image
import os
from app.api.dto.kg_query import PredictedLabel
# (class label, crop) pairs, listed in the order used to index the
# classifier output. Keeping each label beside its crop makes it hard for
# the two derived lists below to drift out of alignment.
_LABEL_TABLE = [
    ('benhVerticilliumWiltCaChua', 'caChua'),
    ('benhChayLaCaChua', 'caChua'),
    ('benhXoanLaCaChua', 'caChua'),
    ('benhDomLaCaChua', 'caChua'),
    ('benhNhenXanhSan', 'san'),
    ('benhKhamLaSan', 'san'),
    ('cassava healthy', 'san'),
    ('benhDomNau', 'san'),
    ('boCanhCungHaiLaNgo', 'ngo'),
    ('corn healthy', 'ngo'),
    ('benhChayLaNgo', 'ngo'),
    ('benhRiSatNgo', 'ngo'),
    ('benhSocLaNgo', 'ngo'),
    ('benhDomLaNgo', 'ngo'),
    ('benhBacLaLua', 'luaNuoc'),
    ('benhDaoOnLua', 'luaNuoc'),
    ('benhDomNauLuaNuoc', 'luaNuoc'),
]

# Class labels, one per classifier output index.
CLASS_NAMES = [label for label, _ in _LABEL_TABLE]

# Crop identifier for the class label at the same index.
CROP_NAMES = [crop for _, crop in _LABEL_TABLE]

# Fine-tuned checkpoint, stored in a "weights" folder next to this module.
_MODULE_DIR = os.path.dirname(__file__)
WEIGHTS_PATH = os.path.join(_MODULE_DIR, 'weights', 'clip_finetuned.pth')
class CLIPFineTuner(nn.Module):
    """Linear classification head stacked on a CLIP image encoder.

    The wrapped model is only used to produce image features; those features
    are computed with gradient tracking disabled, so only ``classifier``
    receives gradients through this module's forward pass.

    Args:
        model: Object exposing ``visual.output_dim`` and ``encode_image(x)``.
        num_classes: Number of output scores produced by the linear head.
    """

    def __init__(self, model, num_classes):
        super().__init__()
        # The attribute names are part of the checkpoint's state_dict keys.
        self.model = model
        feature_dim = model.visual.output_dim
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return the linear head's scores for the image batch ``x``."""
        with torch.no_grad():
            embedding = self.model.encode_image(x)
            # The head's weights are float32, so cast the features to match.
            embedding = embedding.float()
        scores = self.classifier(embedding)
        return scores
class CLIPModule:
    """Inference wrapper around the fine-tuned CLIP crop-disease classifier.

    On construction it loads the "ViT-B/32" CLIP model, wraps it in
    ``CLIPFineTuner``, restores the weights stored at ``WEIGHTS_PATH`` and
    puts the model in eval mode on CUDA when available, otherwise on CPU.
    """

    def __init__(self):
        model, preprocess = clip.load("ViT-B/32", jit=False)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classes = CLASS_NAMES
        # Size the head from the label list rather than a hard-coded 17 so
        # the two cannot silently disagree.
        self.model = CLIPFineTuner(model, len(self.classes))
        self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        # Preprocessing pipeline returned by clip.load alongside the model.
        self.transform = preprocess

    def predict_image(self, image: Image.Image):
        """Classify one image and return every label ranked by confidence.

        Args:
            image: A PIL image. A file path string is also accepted, because
                the input is passed through the same loader as ``__predict``.

        Returns:
            List[PredictedLabel]: One entry per class, carrying ``crop_name``,
            ``label`` and softmax ``confidence``, sorted by confidence in
            descending order.

        Raises:
            ValueError: If the input type is unsupported or the image cannot
                be loaded or preprocessed.
        """
        output = self.__predict(image)
        # A single tolist() replaces one float() conversion per element.
        probabilities = torch.nn.functional.softmax(output, dim=1)[0].tolist()
        predictions: List[PredictedLabel] = [
            PredictedLabel(crop_name=crop_name, label=label, confidence=confidence)
            for crop_name, label, confidence in zip(CROP_NAMES, self.classes, probabilities)
        ]
        # Sort in descending order of probability.
        predictions.sort(key=lambda x: x.confidence, reverse=True)
        return predictions

    def __predict(self, image_input):
        """Run the model on a single image and return its raw scores.

        Args:
            image_input: Path to an image file (str) or a PIL.Image object.

        Returns:
            torch.Tensor: The model output for the image, one score per
            class, with no softmax applied. A 3-D tensor from the transform
            gets a leading batch dimension before the forward pass.

        Raises:
            ValueError: If the input type is unsupported, or if loading or
                transforming the image fails (the original exception is
                attached as the cause).
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            # Already the exception type callers expect; keep it unchanged.
            raise
        except Exception as e:
            raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}") from e
        if image_tensor.dim() == 3:
            # Add the batch dimension for a single image tensor.
            image_tensor = image_tensor.unsqueeze(0)
        image_tensor = image_tensor.to(self.device)
        with torch.no_grad():
            output = self.model(image_tensor)
        return output

    def __handle_image(self, image_input):
        """Normalise the supported input kinds to a PIL image.

        Args:
            image_input: A file path (str) or a PIL.Image object.

        Returns:
            PIL.Image.Image: For a path, the file's contents converted to
            RGB; for a PIL image, the same object unchanged.

        Raises:
            ValueError: If ``image_input`` is neither a str nor a PIL image.
        """
        if isinstance(image_input, str):
            # The context manager releases the file handle as soon as the
            # RGB copy has been made.
            with Image.open(image_input) as opened:
                return opened.convert('RGB')
        if isinstance(image_input, Image.Image):
            return image_input
        raise ValueError("Invalid image input")