Spaces:
Sleeping
Sleeping
| import json | |
| import io | |
| from string import Template | |
| from fastapi import Depends, UploadFile | |
| import asyncio | |
| from PIL import Image | |
| import sqlite3 | |
| from app.api.dto.kg_query import KGQueryRequest, QueryContext, PredictedLabel | |
| from app.core.dependencies import get_all_models | |
| from app.core.type import Node | |
| from app.models.crop_clip import EfficientNetModule | |
| from app.models.gemini_caller import GeminiGenerator | |
| from app.models.knowledge_graph import KnowledgeGraphUtils | |
| from app.utils.constant import EXTRACTED_NODES | |
| from app.utils.data_mapping import VECTOR_EMBEDDINGS_DB_PATH, DataMapping | |
| from app.utils.extract_entity import clean_text, extract_entities | |
| from app.utils.prompt import EXTRACT_NODES_FROM_IMAGE_PROMPT, EXTRACT_NODES_FROM_TEXT_PROMPT, GET_CAPTION_FROM_IMAGE_PROMPT, GET_STATEMENT_FROM_DISEASE_KG, GET_STATEMENT_FROM_ENV_FACTORS_KG, INSTRUCTION | |
| from app.utils.main import unique_captions_by_disease | |
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that can serialize objects exposing a Pydantic-style dump.

    Objects with a ``model_dump`` method (Pydantic v2 style) or a ``dict``
    method (Pydantic v1 style) are replaced by the value that method returns.
    """

    def default(self, obj):
        """Return a JSON-friendly stand-in for ``obj``.

        ``model_dump`` is tried before ``dict``. When called directly with a
        list or tuple, model-like elements are converted and every other
        element is passed through unchanged. Anything else is delegated to
        ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        for dumper_name in ('model_dump', 'dict'):
            if hasattr(obj, dumper_name):
                return getattr(obj, dumper_name)()

        if not isinstance(obj, (list, tuple)):
            return super().default(obj)

        converted = []
        for element in obj:
            looks_like_model = hasattr(element, 'model_dump') or hasattr(element, 'dict')
            converted.append(self.default(element) if looks_like_model else element)
        return converted
def convert_to_json_serializable(obj):
    """Recursively convert ``obj`` into a structure ``json.dumps`` accepts.

    Conversion rules, applied in order:

    * objects with ``model_dump`` (Pydantic v2 style) or ``dict`` (Pydantic v1
      style) are dumped, and the dumped value is converted recursively so
      that nested tuples or non-serializable field values are handled too;
    * lists and tuples become lists of converted items;
    * dicts keep their keys and get converted values;
    * ``None`` is returned as is;
    * any other value is returned unchanged when ``json.dumps`` accepts it,
      otherwise it is replaced by ``str(value)`` and a warning is printed.

    The function never raises: any unexpected error is printed and the
    string form of ``obj`` is returned instead.
    """
    try:
        if hasattr(obj, 'model_dump'):  # Pydantic v2 BaseModel
            return convert_to_json_serializable(obj.model_dump())
        if hasattr(obj, 'dict'):  # Pydantic v1 BaseModel
            return convert_to_json_serializable(obj.dict())
        if isinstance(obj, (list, tuple)):
            return [convert_to_json_serializable(item) for item in obj]
        if isinstance(obj, dict):
            return {key: convert_to_json_serializable(value) for key, value in obj.items()}
        if obj is None:
            return None

        # Scalars and other leaves: keep them when JSON can encode them.
        try:
            json.dumps(obj)  # Test if it's JSON serializable
            return obj
        except (TypeError, ValueError):
            # If it's not serializable, convert to string as fallback
            print(f"Warning: Converting non-serializable object {type(obj)} to string: {obj}")
            return str(obj)
    except Exception as e:
        print(f"Error in convert_to_json_serializable for object {type(obj)}: {e}")
        return str(obj)
# Node objects built once at import time from the EXTRACTED_NODES constant;
# each entry's description is stored under properties["description"] and the
# score is left unset.
extracted_nodes = [
    Node(
        id=entry["id"],
        label=entry["label"],
        name=entry["name"],
        properties={"description": entry["description"]},
        score=None,
    )
    for entry in EXTRACTED_NODES
]
class PredictService:
    """Prediction service combining an image model, a knowledge graph and Gemini.

    ``models`` is a mapping of already-loaded resources. The methods below
    read the keys ``"efficientnet_model"``, ``"knowledge_graph"`` and
    ``"data_mapper"`` from it.
    """

    def __init__(self, models):
        self.models = models

    async def predict_image(self, image: UploadFile):
        """Decode the uploaded image as RGB and return the EfficientNet prediction."""
        efficientnet_model: EfficientNetModule = self.models["efficientnet_model"]
        image_content = image.file.read()
        # Use the stdlib ``io`` module directly rather than PIL's internal import.
        pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
        return efficientnet_model.predict_image(pil_image)

    async def retrieve_kg(self, request: KGQueryRequest):
        """Query the knowledge graph for diseases matching the request context.

        Nodes extracted from ``request.additional_info`` are appended to the
        context, nodes without a score get 0.9, and the environment-factor and
        symptom queries run concurrently. The diseases found are appended to
        ``context.nodes`` (sorted by score, descending). When the context
        carries predicted labels, ``context.final_labels`` is computed too.

        Returns a dict with the keys ``"context"``, ``"env_results"`` and
        ``"symptom_results"``. Any exception is printed and re-raised.
        """
        try:
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            if not request.context:
                request.context = QueryContext()
            context = request.context
            if request.crop_id:
                context.crop_id = request.crop_id
            # Guarantee a list up front so the loops and extend() calls below
            # work even when no additional info was supplied.
            if context.nodes is None:
                context.nodes = []
            if request.additional_info:
                additional_nodes = await self.__get_nodes_from_additional_info_async(
                    request.additional_info, self.models["data_mapper"]
                )
                context.nodes = context.nodes + additional_nodes
            for node in context.nodes:
                if node.score is None:
                    node.score = 0.9

            env_results, symptom_results = await asyncio.gather(
                kg.get_disease_from_env_factors(context.crop_id, context.nodes),
                kg.get_disease_from_symptoms(context.crop_id, context.nodes),
            )

            context.nodes.extend(env_result["disease"] for env_result in env_results)
            context.nodes.extend(symptom_result["disease"] for symptom_result in symptom_results)
            print(context.nodes)
            context.nodes.sort(key=lambda x: x.score, reverse=True)

            # Combine image-model labels with the knowledge-graph evidence.
            if context.predicted_labels:
                print("Got predicted labels")
                context.final_labels = self.calculate_final_labels(
                    context.predicted_labels,
                    env_results,
                    symptom_results,
                    context.crop_id
                )
            return {
                "context": context,
                "env_results": env_results,
                "symptom_results": symptom_results
            }
        except Exception as e:
            print(e)
            raise

    @staticmethod
    def _accumulate_kg_scores(label_scores, results, crop_id, weight, source):
        """Fold knowledge-graph disease hits into ``label_scores`` in place.

        Each hit raises the running score by
        ``disease.score * weight * (1 - current_score)``.
        """
        for result in results:
            disease = result.get("disease")
            if disease and hasattr(disease, 'score'):
                key = f"{crop_id}_{disease.id}"
                print(f"{source} key: {key} score: {disease.score}")
                entry = label_scores.setdefault(key, {
                    "crop_id": crop_id,
                    "label": disease.id,
                    "total_score": 0,
                })
                entry["total_score"] += disease.score * weight * (1 - entry["total_score"])

    def calculate_final_labels(self, predicted_labels, env_result, symptom_result, crop_id):
        """Merge image-model labels with knowledge-graph evidence.

        Scores are accumulated per ``"{crop_id}_{label}"`` key:

        - predicted_labels: each label's ``confidence`` is added as is
          (logged with the "CLIP" prefix);
        - symptom_result: each disease adds
          ``score * 0.2 * (1 - current_score)``;
        - env_result: each disease adds
          ``score * 0.3 * (1 - current_score)``.

        NOTE: an earlier docstring described a 0.4/0.3/0.3 weighted average;
        the code has never implemented that, so the actual behaviour is
        documented here instead.

        Returns ``PredictedLabel`` objects with confidence capped at 1.0,
        sorted by confidence (descending), keeping only those above 0.1.
        """
        # Weight
        ENV_WEIGHT = 0.3
        SYMPTOM_WEIGHT = 0.2

        # Accumulated score for every crop/disease combination.
        label_scores = {}

        # 1. Scores from the image model
        for label in predicted_labels:
            key = f"{label.crop_id}_{label.label}"
            print(f"CLIP key: {key} score: {label.confidence}")
            entry = label_scores.setdefault(key, {
                "crop_id": label.crop_id,
                "label": label.label,
                "total_score": 0,
            })
            entry["total_score"] += label.confidence

        # 2. Scores from symptoms
        self._accumulate_kg_scores(label_scores, symptom_result, crop_id, SYMPTOM_WEIGHT, "Symptom")

        # 3. Scores from environmental factors
        self._accumulate_kg_scores(label_scores, env_result, crop_id, ENV_WEIGHT, "Env")

        # Build final labels; cap the confidence so it never exceeds 1.0.
        final_labels = [
            PredictedLabel(
                crop_id=data["crop_id"],
                label=data["label"],
                confidence=min(data["total_score"], 1.0)
            )
            for data in label_scores.values()
        ]

        # Sort by confidence (descending) and drop low-confidence labels.
        final_labels.sort(key=lambda x: x.confidence, reverse=True)
        print(final_labels)
        return [label for label in final_labels if label.confidence > 0.1]

    # TODO: (marker left without a description by the original author)
    async def get_nodes_from_image(self, image: UploadFile):
        """Ask Gemini which known symptoms are visible in the uploaded image.

        The prompt lists every node returned by
        ``data_mapper.get_embedding_by_label("Symptom")``; the reply is parsed
        as JSON and its ``"ids"`` entry is mapped back to those nodes. Ids
        that match no known symptom are skipped. Returns ``[]`` on any error.
        """
        try:
            gemini = GeminiGenerator()
            symptoms = self.models["data_mapper"].get_embedding_by_label("Symptom")
            symptom_list = "\n".join(f"- id:{node.id} - name:{node.name}" for node in symptoms)
            prompt = Template(EXTRACT_NODES_FROM_IMAGE_PROMPT).substitute(symptom_list=symptom_list)
            image_content = image.file.read()
            pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
            response = gemini.generate(prompt, image=pil_image)
            ids = json.loads(clean_text(response.text))["ids"]
            print(ids)

            # First occurrence wins, matching a linear first-match search.
            symptoms_by_id = {}
            for symptom in symptoms:
                symptoms_by_id.setdefault(symptom.id, symptom)
            return [symptoms_by_id[node_id] for node_id in ids if node_id in symptoms_by_id]
        except Exception as e:
            print(f"Error while extract knowledge entities from image: {str(e)}")
            return []

    async def __get_nodes_from_additional_info_async(self, additional_info: str, data_mapper: DataMapping):
        """Look up the top 3 matching nodes for every entity found in the text.

        Lookups run concurrently. A failed lookup is reported with ``print``
        and skipped so the remaining entities still contribute.
        """
        entities = extract_entities(additional_info)
        if not entities:
            return []
        entities = list(entities)

        tasks = [
            asyncio.create_task(
                data_mapper.get_top_result_by_text_async(entity.name, 3),
                name=f"query_entity_{entity.name}"
            )
            for entity in entities
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        top_results: list[Node] = []
        for entity, result in zip(entities, results):
            # gather() may hand back BaseException instances such as
            # CancelledError, not only Exception subclasses.
            if isinstance(result, BaseException):
                print(f"Warning: lookup for entity {entity.name!r} failed: {result}")
                continue
            top_results.extend(result)
        return top_results

    def get_embedding_by_id_threadsafe(self, id):
        """Fetch one row from ``embeddings`` where ``e_index`` equals ``id``.

        A fresh SQLite connection is opened for every call and closed before
        returning, even when the query fails. Returns the row from
        ``fetchone()`` (``None`` when nothing matches).
        """
        # Each call creates its own connection.
        conn = sqlite3.connect(VECTOR_EMBEDDINGS_DB_PATH, check_same_thread=False)
        try:
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT * FROM embeddings WHERE e_index = ?", (id,))
                return cursor.fetchone()
            finally:
                cursor.close()
        finally:
            # Close the connection once we are done with it.
            conn.close()

    async def retrieve_kg_text(self, request: KGQueryRequest):
        """Query the knowledge graph from free text and phrase the best match.

        Nodes are extracted from ``request.additional_info`` and used for the
        environment-factor and symptom queries, which run concurrently. When
        the request context carries predicted labels, the result whose disease
        id equals the first label is turned into a statement by Gemini.

        Returns a dict with ``"env_results"``, ``"symptom_results"``,
        ``"env_statement"`` and ``"symptom_statement"``; a statement is
        ``None`` when there is no matching result or no predicted label.
        Any exception is printed and re-raised.
        """
        try:
            nodes = await self.get_nodes_from_text(request.additional_info)
            kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
            env_results, symptom_results = await asyncio.gather(
                kg.get_disease_from_env_factors(request.crop_id, nodes),
                kg.get_disease_from_symptoms(request.crop_id, nodes),
            )

            # The context and its labels are optional; without them there is
            # nothing to match the results against.
            predicted_labels = request.context.predicted_labels if request.context else None
            best_label = predicted_labels[0].label if predicted_labels else None

            best_env_result = None
            best_symptom_result = None
            if best_label is not None:
                best_env_result = next(
                    (result for result in env_results if result["disease"].id == best_label), None
                )
                best_symptom_result = next(
                    (result for result in symptom_results if result["disease"].id == best_label), None
                )

            prompt1 = None
            prompt2 = None
            result1 = None
            result2 = None
            if best_env_result:
                prompt1 = Template(GET_STATEMENT_FROM_ENV_FACTORS_KG).substitute(context=str(best_env_result))
            if best_symptom_result:
                prompt2 = Template(GET_STATEMENT_FROM_DISEASE_KG).substitute(context=str(best_symptom_result))

            gemini = GeminiGenerator()
            print(prompt1)
            if prompt1:
                result1 = gemini.generate(prompt1)
            if prompt2:
                result2 = gemini.generate(prompt2)
            return {
                "env_results": env_results,
                "symptom_results": symptom_results,
                "env_statement": result1.text if result1 else None,
                "symptom_statement": result2.text if result2 else None
            }
        except Exception as e:
            print(e)
            raise

    async def get_nodes_from_text(self, text: str):
        """Ask Gemini which of the predefined nodes are mentioned in ``text``.

        The prompt lists every entry of the module-level ``extracted_nodes``,
        one per line; the reply is parsed as JSON and its ``"ids"`` entry is
        mapped back to those nodes. Ids that match no node are skipped.
        Returns ``[]`` on any error.
        """
        try:
            gemini = GeminiGenerator()
            # Join into lines so the prompt shows a bullet list, not a list repr.
            node_list = "\n".join(
                f" + id:{node.id}, name:{node.name}, description:{node.properties.get('description', '')}"
                for node in extracted_nodes
            )
            prompt = Template(EXTRACT_NODES_FROM_TEXT_PROMPT).substitute(text=text, node_list=node_list)
            response = gemini.generate(prompt)
            print(response)
            ids = json.loads(clean_text(response.text))["ids"]
            print(ids)

            # First occurrence wins, matching a linear first-match search.
            nodes_by_id = {}
            for node in extracted_nodes:
                nodes_by_id.setdefault(node.id, node)
            return [nodes_by_id[node_id] for node_id in ids if node_id in nodes_by_id]
        except Exception as e:
            print(e)
            return []

    # NOTE: disabled in the original source; kept for reference.
    # async def get_all_nodes(self):
    #     try:
    #         kg: KnowledgeGraphUtils = self.models["knowledge_graph"]
    #         list_nodes = await kg.get_all_nodes()
    #         return [dict(node[0], **{"label": "Symptom"}) for node in list_nodes]
    #     except Exception as e:
    #         print(e)
    #         return []

    async def get_caption(self, image: UploadFile):
        """Ask Gemini for a caption of the uploaded image.

        The reply is parsed as JSON and its ``"caption"`` entry is returned.
        Returns ``None`` (after printing the error) when anything fails.
        """
        try:
            gemini = GeminiGenerator()
            prompt = Template(GET_CAPTION_FROM_IMAGE_PROMPT).substitute(caption_list=unique_captions_by_disease)
            image_content = image.file.read()
            pil_image = Image.open(io.BytesIO(image_content)).convert('RGB')
            response = gemini.generate(prompt, image=pil_image)
            return json.loads(clean_text(response.text))["caption"]
        except Exception as e:
            print(e)
            return None
def get_predict_service(models = Depends(get_all_models)):
    """Dependency provider: wrap the injected ``models`` in a ``PredictService``."""
    service = PredictService(models)
    return service