"""
Model initialization and inference logic for image generation.

This module handles loading the diffusion model and provides functions for
generating images from text prompts with error handling.
"""

import logging
import random
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image

from config import MODEL_REPO_ID, MAX_SEED

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ModelManager:
    """Manages the diffusion model for image generation.

    Construction only selects the device and dtype; the pipeline itself is
    loaded by ``load_model()``, which must be called before ``generate_image()``.
    """

    def __init__(self):
        """Choose the compute device and dtype; the pipeline starts unloaded (``None``)."""
        cuda_available = torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        # Half precision is used only when running on CUDA.
        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        self.pipe = None

    def load_model(self) -> None:
        """
        Load the diffusion pipeline from ``MODEL_REPO_ID`` and move it to ``self.device``.

        ``self.pipe`` is only assigned once both loading and the device move
        succeed, so a failure never leaves a half-initialised pipeline behind.

        Raises:
            RuntimeError: If loading or moving the pipeline fails. The original
                exception is chained as ``__cause__``.
        """
        logger.info("Loading model %s on %s with %s", MODEL_REPO_ID, self.device, self.torch_dtype)
        try:
            pipe = DiffusionPipeline.from_pretrained(
                MODEL_REPO_ID,
                torch_dtype=self.torch_dtype,
            )
            pipe = pipe.to(self.device)
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise RuntimeError(f"Failed to load model: {e}") from e
        self.pipe = pipe
        logger.info("Model loaded successfully")

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        seed: int = 0,
        randomize_seed: bool = True,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 0.0,
        num_inference_steps: int = 2,
        progress_callback: Optional[Callable[..., Any]] = None,
    ) -> Tuple[Optional[Image.Image], int]:
        """
        Generate an image based on the provided prompt and parameters.

        Args:
            prompt: Text description of the desired image. An empty or
                whitespace-only prompt is replaced by "A beautiful landscape".
            negative_prompt: Text description of what to avoid in the image.
            seed: Random seed for reproducibility (ignored when
                ``randomize_seed`` is True).
            randomize_seed: Whether to draw a fresh seed in ``[0, MAX_SEED]``.
            width: Width of the generated image.
            height: Height of the generated image.
            guidance_scale: How closely to follow the prompt.
            num_inference_steps: Number of denoising steps.
            progress_callback: Optional per-step callback, forwarded to the
                pipeline as its ``callback`` argument only when provided.

        Returns:
            Tuple of ``(image, seed_used)``. ``image`` is ``None`` if the model
            is not loaded or generation failed; failures are logged, not raised.
        """
        if self.pipe is None:
            logger.error("Model not loaded. Call load_model() first.")
            return None, seed

        # Validate inputs
        if not prompt or not prompt.strip():
            logger.warning("Empty prompt provided, using default")
            prompt = "A beautiful landscape"

        # Handle seed randomization
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        pipe_kwargs: Dict[str, Any] = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "width": width,
            "height": height,
        }
        # In diffusers, ``callback`` is deprecated and not accepted by every
        # pipeline, so passing ``callback=None`` unconditionally can make every
        # call fail. Only forward it when the caller actually supplied one.
        if progress_callback is not None:
            pipe_kwargs["callback"] = progress_callback

        try:
            # Build the generator inside the try so that an unusable seed is
            # reported like any other generation failure instead of raising.
            pipe_kwargs["generator"] = torch.Generator(device=self.device).manual_seed(seed)
            logger.info("Generating image with prompt: '%s'", prompt)
            result = self.pipe(**pipe_kwargs)
            image = result.images[0]
        except Exception as e:
            logger.exception("Error generating image: %s", e)
            return None, seed

        logger.info("Image generated successfully with seed %s", seed)
        return image, seed