Spaces:
Sleeping
Sleeping
| import logging | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from app.core.config import get_settings | |
| from app.api.routes import router as api_router | |
| from app.models.crop_clip import CLIPModule | |
| from app.models.gemini_caller import GeminiGenerator | |
| from app.utils.data_mapping import DataMapping, SingletonModel | |
| from app.models.knowledge_graph import KnowledgeGraphUtils, Neo4jConnection | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
# Configure root logging at INFO when this module is imported
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Settings object consulted by ModelLoader (load_* feature flags) and by the
# FastAPI app constructor (debug flag).
settings = get_settings()
class ModelLoader:
    """Load, hold, and release the optional model/service components of the API.

    Each component is created only when its flag on the module-level
    ``settings`` object is truthy (``load_clip_model``, ``load_gemini_model``,
    ``load_data_mapper``, ``load_knowledge_graph``). A component that was not
    loaded stays ``None``, so callers can check for it without risking
    ``AttributeError``.
    """

    def __init__(self):
        # Declare every attribute load_models() may assign, so a disabled
        # component reads as None instead of being missing entirely.
        self.clip_model = None
        self.gemini_model = None
        self.sentence_transformer = None  # not assigned by load_models(); kept for compatibility
        self.data_mapper = None
        self.knowledge_graph = None
        self.neo4j_connection = None  # not assigned by load_models(); closed by close() if set elsewhere

    def load_models(self):
        """Instantiate each component whose ``settings`` flag is enabled.

        This is a synchronous method; the lifespan handler in this module runs
        it in an executor thread so the event loop is not blocked.

        Raises:
            Exception: any error from a component constructor is logged with
                its traceback and re-raised to the caller.
        """
        try:
            if settings.load_clip_model:
                logger.info("Loading CLIP model...")
                self.clip_model = CLIPModule()
                logger.info("CLIP model loaded successfully")
            if settings.load_gemini_model:
                logger.info("Loading Gemini model...")
                self.gemini_model = GeminiGenerator()
                logger.info("Gemini model loaded successfully")
            if settings.load_data_mapper:
                logger.info("Loading DataMapper model...")
                self.data_mapper = DataMapping()
                logger.info("DataMapper model loaded successfully")
            if settings.load_knowledge_graph:
                logger.info("Connecting to Knowledge Graph...")
                self.knowledge_graph = KnowledgeGraphUtils()
                logger.info("Knowledge Graph connection established")
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Failed to load models: %s", e)
            raise

    def close(self):
        """Close open connections and drop references to all loaded components.

        Safe to call when nothing was loaded. A failure while closing one
        resource is logged and does not prevent the rest of the cleanup.
        """
        # load_models() stores the graph client in ``knowledge_graph``, never in
        # ``neo4j_connection``, so both are checked. Only objects exposing a
        # callable close() are closed.
        # NOTE(review): assumes KnowledgeGraphUtils owns a Neo4j connection and
        # exposes close() - confirm in app.models.knowledge_graph.
        closable = (
            ("neo4j_connection", "Neo4j connection"),
            ("knowledge_graph", "Knowledge Graph connection"),
        )
        for attr, label in closable:
            resource = getattr(self, attr, None)
            closer = getattr(resource, "close", None)
            if resource is None or not callable(closer):
                continue
            logger.info("Closing %s...", label)
            try:
                closer()
            except Exception:
                # Best-effort shutdown: keep releasing the remaining components.
                logger.exception("Error while closing %s", label)

        self.clip_model = None
        self.gemini_model = None
        self.sentence_transformer = None
        self.data_mapper = None
        self.knowledge_graph = None
        self.neo4j_connection = None
        logger.info("Models released")
# Lifespan event handler
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models before the app serves requests and release them on shutdown.

    ``ModelLoader.load_models`` is synchronous, so it runs in the event loop's
    default executor (a worker thread) to keep the loop responsive during
    startup. Cleanup runs in ``finally``, so components are released even when
    startup or request serving raises.

    Args:
        app: the FastAPI application; ``app.state.model_loader`` must already
            hold a ``ModelLoader`` instance.
    """
    model_loader = app.state.model_loader
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, model_loader.load_models)
        logger.info("Application startup complete")
        yield
    finally:
        model_loader.close()
        logger.info("Application shutdown complete")
# Application instance. No models are loaded here: the lifespan handler loads
# them at server startup through the ModelLoader stored on app.state.
app = FastAPI(
    title="Crop Diagnosis Knowledge Graph API",
    description="API for querying crop diagnosis knowledge graph using LangChain",
    version="1.0.0",
    debug=settings.debug,
    lifespan=lifespan,
)

# Every route from app.api.routes is served under the /api prefix.
app.include_router(api_router, prefix="/api")

# Shared component registry. It starts empty (all attributes None); lifespan
# populates it on startup and releases it on shutdown.
app.state.model_loader = ModelLoader()
@app.get("/")
async def root():
    """Return a welcome message for the API's root URL.

    The route decorator registers this handler with FastAPI; without it the
    function was never reachable over HTTP.
    """
    return {"message": "Welcome to Crop Diagnosis Knowledge Graph API"}