File size: 3,547 Bytes
88cc76c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import List
import torch.nn as nn
import torch
from torchvision import transforms
import clip
from PIL import Image
import os

from app.api.dto.kg_query import PredictedLabel

# Output labels of the classifier, in logit order. Entries are grouped by
# crop; CROP_NAMES below gives the crop for each position.
CLASS_NAMES = [
    # crop 'caChua'
    'benhVerticilliumWiltCaChua', 'benhChayLaCaChua', 'benhXoanLaCaChua', 'benhDomLaCaChua',
    # crop 'san'
    'benhNhenXanhSan', 'benhKhamLaSan', 'cassava healthy', 'benhDomNau',
    # crop 'ngo'
    'boCanhCungHaiLaNgo', 'corn healthy', 'benhChayLaNgo', 'benhRiSatNgo', 'benhSocLaNgo',
    'benhDomLaNgo',
    # crop 'luaNuoc'
    'benhBacLaLua', 'benhDaoOnLua', 'benhDomNauLuaNuoc',
]

# Crop name for each entry of CLASS_NAMES (index-aligned): the group sizes
# must match the per-crop groups above.
CROP_NAMES = ['caChua'] * 4 + ['san'] * 4 + ['ngo'] * 6 + ['luaNuoc'] * 3

# Fine-tuned classifier weights, stored next to this module.
WEIGHTS_PATH = os.path.join(
    os.path.dirname(__file__), 'weights', 'clip_finetuned.pth'
)

class CLIPFineTuner(nn.Module):
    """Linear classification head on top of a CLIP image encoder.

    The attribute names ``model`` and ``classifier`` determine the keys of
    this module's state_dict (which is restored from saved weights), so they
    must not be renamed.
    """

    def __init__(self, model, num_classes):
        super().__init__()
        # Backbone: must provide ``encode_image`` and ``visual.output_dim``.
        self.model = model
        feature_dim = model.visual.output_dim
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return unnormalized class scores (no softmax) for the batch ``x``."""
        # No gradients are tracked through the backbone; only the linear head
        # participates in autograd.
        with torch.no_grad():
            embeddings = self.model.encode_image(x)
        # Cast to float32 so the features match the Linear head's parameters
        # (the backbone's output dtype may differ, e.g. half precision).
        return self.classifier(embeddings.float())

class CLIPModule:
    """Image classifier over the labels in CLASS_NAMES, using fine-tuned CLIP.

    Loads the CLIP ViT-B/32 backbone, wraps it in ``CLIPFineTuner`` with one
    output per entry of ``CLASS_NAMES``, restores the fine-tuned weights from
    ``WEIGHTS_PATH`` and keeps the model in eval mode on the chosen device.
    """

    def __init__(self):
        # Choose the device first so the backbone is loaded onto it explicitly
        # rather than relying on clip.load's own default.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model, preprocess = clip.load("ViT-B/32", device=self.device, jit=False)
        # Size the head from the label list instead of a hard-coded count; a
        # mismatch with the saved weights then fails loudly in load_state_dict.
        self.model = CLIPFineTuner(model, len(CLASS_NAMES))
        self.model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        self.classes = CLASS_NAMES
        self.transform = preprocess

    def predict_image(self, image: Image.Image) -> List[PredictedLabel]:
        """Classify an image and return every label ranked by confidence.

        Args:
            image: A PIL image (a file path is also accepted by the helper).

        Returns:
            One ``PredictedLabel`` per class, carrying the crop name, the label
            and its softmax probability, sorted from most to least likely.

        Raises:
            ValueError: If the input cannot be loaded or preprocessed.
        """
        logits = self.__predict(image)
        # Softmax over the class axis, take the single image of the batch, and
        # copy it to host memory once instead of syncing per element.
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

        predictions: List[PredictedLabel] = [
            PredictedLabel(crop_name=crop, label=label, confidence=float(prob))
            for crop, label, prob in zip(CROP_NAMES, self.classes, probabilities)
        ]

        # Sort by descending probability.
        predictions.sort(key=lambda p: p.confidence, reverse=True)

        return predictions

    def __predict(self, image_input):
        """Run the fine-tuned model on a single image.

        Args:
            image_input: Path to an image file (str or os.PathLike) or a
                PIL.Image object.

        Returns:
            torch.Tensor: Raw logits (softmax NOT applied), one row per image
            in the batch and one column per class, located on ``self.device``.

        Raises:
            ValueError: If the input type is unsupported or the image cannot
                be opened/preprocessed.
        """
        try:
            image = self.__handle_image(image_input)
            image_tensor = self.transform(image)
        except ValueError:
            # Already a caller-facing error; propagate unchanged.
            raise
        except Exception as e:
            # Normalize any loading/preprocessing failure to ValueError while
            # keeping the original exception as the cause for debugging.
            raise ValueError(f"Không thể xử lý ảnh đầu vào: {str(e)}") from e

        # Add the batch dimension if the transform produced a single image.
        if image_tensor.dim() == 3:
            image_tensor = image_tensor.unsqueeze(0)

        image_tensor = image_tensor.to(self.device)

        with torch.no_grad():
            logits = self.model(image_tensor)

        return logits

    def __handle_image(self, image_input):
        """Return an RGB PIL image from a file path or an existing PIL image.

        Raises:
            ValueError: If ``image_input`` is neither a path nor a PIL image.
        """
        if isinstance(image_input, (str, os.PathLike)):
            # Context manager releases the file handle deterministically,
            # including when decoding fails part-way.
            with Image.open(image_input) as img:
                return img.convert('RGB')
        if isinstance(image_input, Image.Image):
            # Normalize the mode so in-memory images match file inputs.
            if image_input.mode == 'RGB':
                return image_input
            return image_input.convert('RGB')
        raise ValueError("Invalid image input")