"""
Tests for the image generation application.

This module contains unit tests for the various components of the application:
the configuration constants, the ModelManager wrapper around the diffusion
pipeline, and the utility helpers (image saving, info formatting, history).
"""

import unittest
from unittest.mock import MagicMock, patch
import os
from pathlib import Path

import numpy as np
from PIL import Image

# Import application modules
from config import MODEL_REPO_ID, MAX_SEED
from model import ModelManager
from utils import save_image, format_generation_info, GenerationHistory


class TestConfig(unittest.TestCase):
    """Test the configuration module."""

    def test_config_values(self):
        """Test that configuration values are set to the expected defaults."""
        from config import (
            MODEL_REPO_ID,
            DEFAULT_GUIDANCE_SCALE,
            DEFAULT_INFERENCE_STEPS,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            MAX_IMAGE_SIZE,
            EXAMPLE_PROMPTS,
        )

        self.assertEqual(MODEL_REPO_ID, "stabilityai/sdxl-turbo")
        self.assertEqual(DEFAULT_GUIDANCE_SCALE, 0.0)
        self.assertEqual(DEFAULT_INFERENCE_STEPS, 2)
        self.assertEqual(DEFAULT_WIDTH, 1024)
        self.assertEqual(DEFAULT_HEIGHT, 1024)
        self.assertEqual(MAX_IMAGE_SIZE, 1024)
        self.assertIsInstance(EXAMPLE_PROMPTS, list)
        # assertGreater reports the actual length on failure; assertTrue(...)
        # would only say "False is not true".
        self.assertGreater(len(EXAMPLE_PROMPTS), 0)


class TestModelManager(unittest.TestCase):
    """Test the ModelManager class.

    Each test patches ``DiffusionPipeline`` (or its ``from_pretrained``) in the
    ``model`` module so the tests do not depend on the real pipeline.
    """

    @patch('model.DiffusionPipeline')
    def test_init(self, mock_pipeline):
        """A fresh manager picks a device and has no pipeline loaded yet."""
        manager = ModelManager()

        self.assertIn(manager.device, ["cuda", "cpu"])
        self.assertIsNone(manager.pipe)

    @patch('model.DiffusionPipeline.from_pretrained')
    def test_load_model(self, mock_from_pretrained):
        """load_model builds the pipeline from MODEL_REPO_ID and moves it to the device."""
        # `.to()` returns the same mock, so the manager ends up holding mock_pipe
        # whether or not it reassigns the result of `.to()`.
        mock_pipe = MagicMock()
        mock_from_pretrained.return_value = mock_pipe
        mock_pipe.to.return_value = mock_pipe

        manager = ModelManager()
        manager.load_model()

        mock_from_pretrained.assert_called_once_with(
            MODEL_REPO_ID, torch_dtype=manager.torch_dtype
        )
        mock_pipe.to.assert_called_once_with(manager.device)
        self.assertEqual(manager.pipe, mock_pipe)

    @patch('model.DiffusionPipeline')
    def test_generate_image_with_randomize(self, mock_pipeline):
        """Image generation with a randomized seed returns the pipeline image and a seed in range."""
        manager = ModelManager()
        # Inject a fake pipeline directly; calling it yields an object whose
        # `.images` list contains our sentinel image.
        manager.pipe = MagicMock()
        mock_image = MagicMock()
        manager.pipe.return_value = MagicMock(images=[mock_image])

        image, seed = manager.generate_image(
            prompt="test prompt",
            randomize_seed=True,
        )

        # The manager should hand back exactly the object the pipeline produced.
        self.assertIs(image, mock_image)
        self.assertGreaterEqual(seed, 0)
        self.assertLessEqual(seed, MAX_SEED)


class TestUtils(unittest.TestCase):
    """Test utility functions."""

    def setUp(self):
        """Create a small in-memory test image and make sure OUTPUTS_DIR exists."""
        self.test_image = Image.new('RGB', (100, 100), color='red')

        from utils import OUTPUTS_DIR
        self.test_outputs_dir = OUTPUTS_DIR
        # parents=True so setUp does not fail when the directory's parent is
        # missing too (e.g. on a fresh checkout).
        self.test_outputs_dir.mkdir(parents=True, exist_ok=True)

    def test_save_image(self):
        """save_image writes a .png file for the given image and prompt."""
        prompt = "test image prompt"
        filepath = save_image(self.test_image, prompt)
        # Register cleanup before asserting so the file is removed even when an
        # assertion below fails; missing_ok avoids a second error if the file
        # was never written.
        self.addCleanup(Path(filepath).unlink, missing_ok=True)

        self.assertTrue(os.path.exists(filepath))
        # str() so the check works whether save_image returns a str or a Path.
        self.assertTrue(str(filepath).endswith(".png"))

    def test_format_generation_info(self):
        """Every generation parameter appears in the formatted info string."""
        prompt = "test prompt"
        negative_prompt = "test negative"
        seed = 42
        width = 512
        height = 512
        guidance_scale = 7.5
        steps = 30

        info = format_generation_info(
            prompt, negative_prompt, seed, width, height, guidance_scale, steps
        )

        # subTest reports every missing fragment, not just the first one.
        for expected in (
            prompt,
            negative_prompt,
            str(seed),
            str(width),
            str(height),
            str(guidance_scale),
            str(steps),
        ):
            with self.subTest(expected=expected):
                self.assertIn(expected, info)

    def test_generation_history(self):
        """GenerationHistory caps its size, keeps the newest entry last, and clears."""
        history = GenerationHistory(max_history=3)

        # Empty history
        self.assertEqual(len(history.history), 0)
        self.assertEqual(history.get_latest(), [])

        # Add more entries than max_history allows.
        for i in range(5):
            history.add(
                self.test_image,
                f"prompt {i}",
                f"negative {i}",
                i,
                512,
                512,
                7.5,
                30,
            )

        # History is limited to max_history.
        self.assertEqual(len(history.history), 3)

        # Entries are in insertion order (newest last).
        latest = history.get_latest(1)[0]
        self.assertEqual(latest["prompt"], "prompt 4")

        # Clear empties the history.
        history.clear()
        self.assertEqual(len(history.history), 0)


if __name__ == '__main__':
    unittest.main()