# Sontranwakumo
# feat: fake data
# 80ad49a
import json
import io
from string import Template
from fastapi import Depends, UploadFile
import asyncio
from PIL import Image
import sqlite3
from app.api.dto.kg_query import KGQueryRequest, QueryContext, PredictedLabel
from app.core.dependencies import get_all_models
from app.core.type import Node
from app.models.crop_clip import EfficientNetModule
from app.models.gemini_caller import GeminiGenerator
from app.models.knowledge_graph import KnowledgeGraphUtils
from app.utils.constant import EXTRACTED_NODES
from app.utils.data_mapping import VECTOR_EMBEDDINGS_DB_PATH, DataMapping
from app.utils.extract_entity import clean_text, extract_entities
from app.utils.prompt import EXTRACT_NODES_FROM_IMAGE_PROMPT, EXTRACT_NODES_FROM_TEXT_PROMPT, GET_CAPTION_FROM_IMAGE_PROMPT, GET_STATEMENT_FROM_DISEASE_KG, GET_STATEMENT_FROM_ENV_FACTORS_KG, INSTRUCTION
from app.utils.main import unique_captions_by_disease
class CustomJSONEncoder(json.JSONEncoder):
    """json.JSONEncoder that can serialize Pydantic models (v2 and v1)."""

    def default(self, obj):
        """Hook called by the base encoder for objects it cannot serialize."""
        # Pydantic v2 models expose model_dump(); v1 models expose dict().
        dumper = getattr(obj, 'model_dump', None) or getattr(obj, 'dict', None)
        if dumper is not None:
            return dumper()
        if isinstance(obj, (list, tuple)):
            converted = []
            for item in obj:
                if hasattr(item, 'model_dump') or hasattr(item, 'dict'):
                    converted.append(self.default(item))
                else:
                    converted.append(item)
            return converted
        return super().default(obj)
def convert_to_json_serializable(obj):
    """Recursively convert *obj* into plain JSON-serializable data.

    Handles Pydantic models (v2 then v1), dicts, lists and tuples
    (tuples become lists). Leaf values that json cannot handle are
    stringified as a last resort so callers never crash on json.dumps.
    """
    try:
        if hasattr(obj, 'model_dump'):  # Pydantic v2 BaseModel
            return obj.model_dump()
        if hasattr(obj, 'dict'):  # Pydantic v1 BaseModel
            return obj.dict()
        if isinstance(obj, (list, tuple)):
            return [convert_to_json_serializable(element) for element in obj]
        if isinstance(obj, dict):
            return {key: convert_to_json_serializable(value) for key, value in obj.items()}
        if obj is None:
            return None
        # Leaf value: keep it as-is when json can already serialize it.
        try:
            json.dumps(obj)
        except (TypeError, ValueError):
            print(f"Warning: Converting non-serializable object {type(obj)} to string: {obj}")
            return str(obj)
        return obj
    except Exception as e:
        print(f"Error in convert_to_json_serializable for object {type(obj)}: {e}")
        return str(obj)
def _node_from_spec(spec):
    """Build a Node from one raw EXTRACTED_NODES entry (dict with
    'id', 'label', 'name', 'description' keys)."""
    return Node(
        id=spec['id'],
        label=spec['label'],
        name=spec['name'],
        properties={'description': spec['description']},
        score=None,
    )


# Static catalogue of known KG nodes, hydrated once at import time.
extracted_nodes = [_node_from_spec(spec) for spec in EXTRACTED_NODES]
class PredictService:
    """Service layer bundling image classification, knowledge-graph
    retrieval and Gemini-based entity extraction for crop-disease
    prediction.
    """

    def __init__(self, models):
        # models: registry of shared singletons provided by get_all_models;
        # keys used in this class: "efficientnet_model", "knowledge_graph",
        # "data_mapper".
        self.models = models
async def predict_image(self, image: UploadFile):
efficientnet_model: EfficientNetModule = self.models["efficientnet_model"]
image_content = image.file.read()
pil_image = Image.open(Image.io.BytesIO(image_content)).convert('RGB')
return efficientnet_model.predict_image(pil_image)
async def retrieve_kg(self, request: KGQueryRequest):
try:
kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
if not request.context:
request.context = QueryContext()
if request.crop_id:
request.context.crop_id = request.crop_id
if request.additional_info:
additional_nodes = await self.__get_nodes_from_additional_info_async(
request.additional_info, self.models["data_mapper"]
)
if request.context.nodes is None:
request.context.nodes = []
request.context.nodes = request.context.nodes + additional_nodes
for node in request.context.nodes:
if node.score is None:
node.score = 0.9
env_task = asyncio.create_task(
kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
)
symptom_task = asyncio.create_task(
kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
)
env_results, symptom_results = await asyncio.gather(env_task, symptom_task)
context = request.context
context.nodes.extend([env_result["disease"] for env_result in env_results])
context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_results])
print(context.nodes)
context.nodes.sort(key=lambda x: x.score, reverse=True)
# Tính toán final_labels bằng trung bình có trọng số
if context.predicted_labels:
print("Got predicted labels")
context.final_labels = self.calculate_final_labels(
context.predicted_labels,
env_results,
symptom_results,
context.crop_id
)
return {
"context": context,
"env_results": env_results,
"symptom_results": symptom_results
}
except Exception as e:
print(e)
raise e
def calculate_final_labels(self, predicted_labels, env_result, symptom_result, crop_id):
"""
Tính toán final_labels bằng trung bình có trọng số từ:
- predicted_labels: Kết quả từ CLIP model (weight: 0.4)
- env_result: Kết quả từ environmental factors (weight: 0.3)
- symptom_result: Kết quả từ symptoms (weight: 0.3)
"""
# Weight
ENV_WEIGHT = 0.3
SYMPTOM_WEIGHT = 0.2
# Dictionary để tích lũy scores cho mỗi disease/crop combination
label_scores = {}
# 1. Điểm từ CLIP model
for label in predicted_labels:
key = f"{label.crop_id}_{label.label}"
print(f"CLIP key: {key} score: {label.confidence}")
if key not in label_scores:
label_scores[key] = {
"crop_id": label.crop_id,
"label": label.label,
"total_score": 0,
"count": 0
}
label_scores[key]["total_score"] += label.confidence
label_scores[key]["count"] += 1
# 2. Điểm từ symptoms
for symptom in symptom_result:
disease = symptom.get("disease")
if disease and hasattr(disease, 'score'):
key = f"{crop_id}_{disease.id}"
print(f"Symptom key: {key} score: {disease.score}")
if key not in label_scores:
label_scores[key] = {
"crop_id": crop_id,
"label": disease.id,
"total_score": 0,
"count": 0
}
label_scores[key]["total_score"] += disease.score * SYMPTOM_WEIGHT * (1-label_scores[key]["total_score"])
# 3. Điểm từ environmental factors
for env in env_result:
disease = env.get("disease")
if disease and hasattr(disease, 'score'):
# Giả sử disease có thông tin về crop và label
key = f"{crop_id}_{disease.id}"
print(f"Env key: {key} score: {disease.score}")
if key not in label_scores:
label_scores[key] = {
"crop_id": crop_id,
"label": disease.id,
"total_score": 0,
"count": 0
}
label_scores[key]["total_score"] += disease.score * ENV_WEIGHT * (1-label_scores[key]["total_score"])
# Tạo final_labels từ kết quả tính toán
final_labels = []
for key, data in label_scores.items():
final_confidence = data["total_score"]
final_labels.append(PredictedLabel(
crop_id=data["crop_id"],
label=data["label"],
confidence=min(final_confidence, 1.0) # Đảm bảo không vượt quá 1.0
))
# Sắp xếp theo confidence giảm dần và lọc ngưỡng
final_labels.sort(key=lambda x: x.confidence, reverse=True)
print(final_labels)
return [label for label in final_labels if label.confidence > 0.1] # Lọc ngưỡng thấp
# TODO:
async def get_nodes_from_image(self, image: UploadFile):
try:
gemini = GeminiGenerator()
symptoms = self.models["data_mapper"].get_embedding_by_label("Symptom")
symptom_list = [f"- id:{node.id} - name:{node.name}" for node in symptoms]
symptom_list = "\n".join(symptom_list)
prompt = Template(EXTRACT_NODES_FROM_IMAGE_PROMPT).substitute(symptom_list=symptom_list)
image_content = image.file.read()
pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
ids = gemini.generate(prompt, image=pil_image)
ids = (json.loads(clean_text(ids.text)))["ids"]
print(ids)
nodes = []
for id in ids:
node = next((symptom for symptom in symptoms if symptom.id == id), None)
nodes.append(node)
return nodes
except Exception as e:
print(f"Error while extract knowledge entities from image: {str(e)}")
return []
async def __get_nodes_from_additional_info_async(self, additional_info: str, data_mapper: DataMapping):
entities = extract_entities(additional_info)
if not entities:
return []
tasks = []
for entity in entities:
task = asyncio.create_task(
data_mapper.get_top_result_by_text_async(entity.name, 3),
name=f"query_entity_{entity.name}"
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
top_results: list[Node] = []
for i, result in enumerate(results):
if isinstance(result, Exception):
continue
for node in result:
top_results.append(node)
return top_results
def get_embedding_by_id_threadsafe(self, id):
# Mỗi thread tạo connection riêng
conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
cursor = conn.cursor()
try:
cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
result = cursor.fetchone()
return result
finally:
cursor.close() # Đóng connection sau khi dùng xong
conn.close()
    async def retrieve_kg_text(self, request: KGQueryRequest):
        """Text-only KG retrieval.

        Extracts KG nodes from request.additional_info, queries the KG for
        env-factor and symptom matches concurrently, then asks Gemini for
        natural-language statements about the best predicted label.

        NOTE(review): assumes request.context.predicted_labels is non-empty;
        an empty list raises IndexError, which is printed and re-raised —
        confirm callers guarantee this.
        """
        try:
            nodes = await self.get_nodes_from_text(request.additional_info)
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            # Run both KG lookups concurrently.
            env_task = asyncio.create_task(
                kg.get_disease_from_env_factors(request.crop_id, nodes)
            )
            symptom_task = asyncio.create_task(
                kg.get_disease_from_symptoms(request.crop_id, nodes)
            )
            env_results, symptom_results = await asyncio.gather(env_task, symptom_task)
            # Focus the generated statements on the top predicted disease.
            best_label = request.context.predicted_labels[0].label
            best_env_result = next((result for result in env_results if result["disease"].id == best_label), None)
            best_env_result_str = str(best_env_result)
            best_symptom_result = next((result for result in symptom_results if result["disease"].id == best_label), None)
            best_symptom_result_str = str(best_symptom_result)
            prompt1 = None
            prompt2 = None
            result1 = None
            result2 = None
            # Only build a prompt when the corresponding KG result exists.
            if best_env_result:
                prompt1 = Template(GET_STATEMENT_FROM_ENV_FACTORS_KG).substitute(context=best_env_result_str)
            if best_symptom_result:
                prompt2 = Template(GET_STATEMENT_FROM_DISEASE_KG).substitute(context=best_symptom_result_str)
            gemini = GeminiGenerator()
            print(prompt1)
            if prompt1:
                result1 = gemini.generate(prompt1)
            if prompt2:
                result2 = gemini.generate(prompt2)
            return {
                "env_results": env_results,
                "symptom_results": symptom_results,
                "env_statement": result1.text if result1 else None,
                "symptom_statement": result2.text if result2 else None
            }
        except Exception as e:
            print(e)
            raise e
async def get_nodes_from_text(self, text: str):
try:
gemini = GeminiGenerator()
node_list = [f" + id:{node.id}, name:{node.name}, description:{node.properties.get('description', '')}" for node in extracted_nodes]
prompt = Template(EXTRACT_NODES_FROM_TEXT_PROMPT).substitute(text=text, node_list=node_list)
ids = gemini.generate(prompt)
print(ids)
ids = (json.loads(clean_text(ids.text)))["ids"]
print(ids)
nodes = [next((node for node in extracted_nodes if node.id == id), None) for id in ids]
return nodes
except Exception as e:
print(e)
# async def get_all_nodes(self):
# try:
# kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
# list_nodes = await kg.get_all_nodes()
# return [dict(node[0], **{"label": "Symptom"}) for node in list_nodes]
# except Exception as e:
# print(e)
# return []
async def get_caption(self, image: UploadFile):
try:
gemini = GeminiGenerator()
prompt = Template(GET_CAPTION_FROM_IMAGE_PROMPT).substitute(caption_list=unique_captions_by_disease)
image_content = image.file.read()
pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
return (json.loads(clean_text(gemini.generate(prompt, image=pil_image).text)))["caption"]
except Exception as e:
print(e)
return None
def get_predict_service(models = Depends(get_all_models)):
    """FastAPI dependency factory: build a PredictService from the shared
    model registry supplied by get_all_models."""
    return PredictService(models)