import os
import random
import sys
import asyncio
import tempfile
from typing import Sequence, Mapping, Any, Union

import torch
import gradio as gr
from PIL import Image
import numpy as np
import spaces
from huggingface_hub import hf_hub_download


# Download required models at startup
def download_models():
    """Download all required models from the HuggingFace Hub."""
    try:
        print("📥 Downloading FLUX Kontext checkpoint...")
        hf_hub_download(
            repo_id="black-forest-labs/FLUX.1-Kontext-dev",
            filename="flux1-kontext-dev.safetensors",
            local_dir="models/checkpoints",
        )
        print("✅ FLUX Kontext checkpoint downloaded")
    except Exception as e:
        print(f"❌ Error downloading FLUX checkpoint: {e}")

    try:
        print("📥 Downloading VAE model...")
        hf_hub_download(
            repo_id="black-forest-labs/FLUX.1-Kontext-dev",
            filename="ae.safetensors",
            local_dir="models/vae",
        )
        print("✅ VAE model downloaded")
    except Exception as e:
        print(f"❌ Error downloading VAE: {e}")

    try:
        print("📥 Downloading CLIP text encoder...")
        hf_hub_download(
            repo_id="DoozyWo/Kontext_Clip_model",
            filename="model.safetensors",
            local_dir="models/text_encoders",
        )
        print("✅ CLIP text encoder downloaded")
    except Exception as e:
        print(f"❌ Error downloading CLIP text encoder: {e}")

    try:
        print("📥 Downloading T5 text encoder...")
        hf_hub_download(
            repo_id="comfyanonymous/flux_text_encoders",
            filename="t5xxl_fp8_e4m3fn.safetensors",
            local_dir="models/text_encoders",
        )
        print("✅ T5 text encoder downloaded")
    except Exception as e:
        print(f"❌ Error downloading T5 text encoder: {e}")

    try:
        print("📥 Downloading Avatar LoRA...")
        hf_hub_download(
            repo_id="DoozyWo/Kontext_avatar_LoRA",
            filename="Avataar_LoRA_000003000.safetensors",
            local_dir="models/loras",
        )
        print("✅ Avatar LoRA downloaded")
    except Exception as e:
        print(f"❌ Error downloading Avatar LoRA: {e}")

    print("✅ Model downloads completed!")


# Download models on import
download_models()


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the value at the given index of a sequence or mapping."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: str = None) -> str:
    """Recursively look for a path, starting from the given directory and walking upward."""
    if path is None:
        path = os.getcwd()

    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name

    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None

    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    """Add the ComfyUI directory to sys.path."""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")


def add_extra_model_paths() -> None:
    """Parse the optional extra_model_paths.yaml file and register the configured paths."""
    try:
        from main import load_extra_path_config
    except ImportError:
        try:
            from utils.extra_config import load_extra_path_config
        except ImportError:
            print("Could not import load_extra_path_config")
            return

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")


# Initialize ComfyUI
add_comfyui_directory_to_sys_path()
add_extra_model_paths()


async def import_custom_nodes() -> None:
    """Import and initialize ComfyUI custom nodes."""
    import execution
    from nodes import init_extra_nodes
    import server

    # Create an event loop if none exists
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    # Initialize the server and prompt queue
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)

    # init_extra_nodes is async, so await it
    await init_extra_nodes()


# Import NODE_CLASS_MAPPINGS after ComfyUI has been added to sys.path
from nodes import NODE_CLASS_MAPPINGS

# Global initialization state
_initialized = False
_model_loaders = None


async def initialize_models():
    """Initialize and preload models for faster inference."""
    global _initialized, _model_loaders

    if _initialized:
        return _model_loaders

    await import_custom_nodes()

    # Use no_grad instead of inference_mode for better compatibility
    with torch.no_grad():
        # Initialize all node classes
        dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
        vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
        checkpointloadersimple = NODE_CLASS_MAPPINGS["CheckpointLoaderSimple"]()
        loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()

        # Check which CLIP models are available
        text_encoders_path = "models/text_encoders"
        available_files = []
        if os.path.exists(text_encoders_path):
            available_files = os.listdir(text_encoders_path)
        print(f"Available text encoder files: {available_files}")

        # Try different CLIP loading approaches
        try:
            # First try: use the expected model names
            dualcliploader_184 = dualcliploader.load_clip(
                clip_name1="model.safetensors",
                clip_name2="t5xxl_fp8_e4m3fn.safetensors",
                type="flux",
                device="default",
            )
        except Exception as e:
            print(f"First CLIP load attempt failed: {e}")
            try:
                # Second try: use alternative names
                dualcliploader_184 = dualcliploader.load_clip(
                    clip_name1="clip_l.safetensors",
                    clip_name2="t5xxl_fp8_e4m3fn.safetensors",
                    type="flux",
                    device="default",
                )
            except Exception as e2:
                print(f"Second CLIP load attempt failed: {e2}")
                # Third try: download and use the standard FLUX text encoders
                print("📥 Downloading standard FLUX text encoders as fallback...")
                try:
                    hf_hub_download(
                        repo_id="comfyanonymous/flux_text_encoders",
                        filename="clip_l.safetensors",
                        local_dir="models/text_encoders",
                    )
                    hf_hub_download(
                        repo_id="comfyanonymous/flux_text_encoders",
                        filename="t5xxl_fp16.safetensors",
                        local_dir="models/text_encoders",
                    )
                    dualcliploader_184 = dualcliploader.load_clip(
                        clip_name1="clip_l.safetensors",
                        clip_name2="t5xxl_fp16.safetensors",
                        type="flux",
                        device="default",
                    )
                except Exception as e3:
                    print(f"Fallback CLIP download failed: {e3}")
                    raise e3

        vaeloader_39 = vaeloader.load_vae(vae_name="ae.safetensors")

        checkpointloadersimple_188 = checkpointloadersimple.load_checkpoint(
            ckpt_name="flux1-kontext-dev.safetensors"
        )

        loraloadermodelonly_186 = loraloadermodelonly.load_lora_model_only(
            lora_name="Avataar_LoRA_000003000.safetensors",
            strength_model=1,
            model=get_value_at_index(checkpointloadersimple_188, 0),
        )

        # Store all loaded models
        _model_loaders = {
            'clip': dualcliploader_184,
            'vae': vaeloader_39,
            'checkpoint': checkpointloadersimple_188,
            'lora_model': loraloadermodelonly_186,
        }

        # Load models onto the GPU for faster inference, with defensive error handling
        try:
            from comfy import model_management

            # Collect valid models safely
            valid_models = []
            for loader_name, loader in _model_loaders.items():
                try:
                    if loader and len(loader) > 0:
                        model_obj = loader[0]
                        if hasattr(model_obj, 'patcher') and model_obj.patcher is not None:
                            if not isinstance(model_obj.patcher, dict):
                                valid_models.append(model_obj.patcher)
                        elif not isinstance(model_obj, dict):
                            valid_models.append(model_obj)
                except Exception as e:
                    print(f"Warning: Could not process model {loader_name}: {e}")
                    continue

            if valid_models:
                print(f"Loading {len(valid_models)} models to GPU...")
                model_management.load_models_gpu(valid_models)
                print("✅ Models loaded to GPU successfully")
            else:
                print("⚠️ No valid models found for GPU loading")
        except Exception as e:
            print(f"⚠️ Warning: Could not load models to GPU: {e}")
            print("Models will run on CPU/default device")

    _initialized = True
    return _model_loaders


@spaces.GPU(duration=60)
def generate_image(input_image, custom_prompt=""):
    """Synchronous wrapper around generate_image_async - the main entry point."""

    async def _generate():
        return await generate_image_async(input_image, custom_prompt)

    # Run the coroutine in an event loop
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            # A loop is already running, so run the coroutine in a separate thread
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(asyncio.run, _generate())
                return future.result()
        else:
            return loop.run_until_complete(_generate())
    except RuntimeError:
        # No event loop exists, so create one
        return asyncio.run(_generate())


def generate_image_sync(input_image, custom_prompt=""):
    """Alternative synchronous wrapper - delegates to generate_image."""
    return generate_image(input_image, custom_prompt)


async def generate_image_async(input_image, custom_prompt=""):
    """Transform an input image using the Avatar LoRA with the FLUX Kontext model."""
    if input_image is None:
        return None, "Please provide an input image."

    try:
        # Initialize models
        model_loaders = await initialize_models()

        # Use no_grad instead of inference_mode for better compatibility
        with torch.no_grad():
            # Force garbage collection before starting
            import gc

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Initialize node classes
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            imagescaletototalpixels = NODE_CLASS_MAPPINGS["ImageScaleToTotalPixels"]()
            vaeencode = NODE_CLASS_MAPPINGS["VAEEncode"]()
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            referencelatent = NODE_CLASS_MAPPINGS["ReferenceLatent"]()
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            ksampler = NODE_CLASS_MAPPINGS["KSampler"]()
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()

            # Save the input image to a temporary file
            temp_dir = tempfile.mkdtemp()
            temp_image_path = os.path.join(temp_dir, "input_image.jpg")
            input_image.save(temp_image_path)

            try:
                # Load and process the input image
                loadimage_133 = loadimage.load_image(image=temp_image_path)

                # Scale image - detach tensors to avoid version tracking issues
                image_tensor = get_value_at_index(loadimage_133, 0)
                if isinstance(image_tensor, torch.Tensor):
                    image_tensor = image_tensor.detach().clone()

                imagescaletototalpixels_187 = imagescaletototalpixels.upscale(
                    upscale_method="bicubic",
                    megapixels=1,
                    image=image_tensor,
                )

                # Encode image - ensure the tensor is detached
                scaled_image = get_value_at_index(imagescaletototalpixels_187, 0)
                if isinstance(scaled_image, torch.Tensor):
                    scaled_image = scaled_image.detach().clone()

                vaeencode_124 = vaeencode.encode(
                    pixels=scaled_image,
                    vae=get_value_at_index(model_loaders['vae'], 0),
                )

                # Base prompt for the Avatar transformation
                base_prompt = "Turn this into a photorealistic Na'vi character from Avatar, with blue bioluminescent skin, large eyes, and set in the glowing jungle of Pandora."
                # Combine with the custom prompt if provided
                if custom_prompt.strip():
                    full_prompt = f"{base_prompt} {custom_prompt.strip()}"
                else:
                    full_prompt = base_prompt

                # Encode text prompts
                cliptextencode_181 = cliptextencode.encode(
                    text=full_prompt,
                    clip=get_value_at_index(model_loaders['clip'], 0),
                )

                cliptextencode_182 = cliptextencode.encode(
                    text="",
                    clip=get_value_at_index(model_loaders['clip'], 0),
                )

                # Generate the Avatar transformation
                referencelatent_176 = referencelatent.append(
                    conditioning=get_value_at_index(cliptextencode_181, 0),
                    latent=get_value_at_index(vaeencode_124, 0),
                )

                fluxguidance_179 = fluxguidance.append(
                    guidance=4.5,
                    conditioning=get_value_at_index(referencelatent_176, 0),
                )

                # Use a random seed for variety between runs
                random_seed = random.randint(0, 2**32 - 1)

                ksampler_178 = ksampler.sample(
                    seed=random_seed,
                    steps=25,
                    cfg=1,
                    sampler_name="euler",
                    scheduler="simple",
                    denoise=1,
                    model=get_value_at_index(model_loaders['lora_model'], 0),
                    positive=get_value_at_index(fluxguidance_179, 0),
                    negative=get_value_at_index(cliptextencode_182, 0),
                    latent_image=get_value_at_index(vaeencode_124, 0),
                )

                vaedecode_177 = vaedecode.decode(
                    samples=get_value_at_index(ksampler_178, 0),
                    vae=get_value_at_index(model_loaders['vae'], 0),
                )

                # Get the result image and handle tensor conversion
                result_images = get_value_at_index(vaedecode_177, 0)

                # Convert the tensor to a PIL Image with proper detachment
                if isinstance(result_images, torch.Tensor):
                    # Detach and clone to avoid version tracking issues
                    image_tensor = result_images.detach().clone().squeeze(0)

                    # Move to CPU if on GPU
                    if image_tensor.is_cuda:
                        image_tensor = image_tensor.cpu()

                    # Convert to a uint8 numpy array
                    image_np = image_tensor.numpy()
                    image_np = np.clip(image_np, 0.0, 1.0)
                    image_np = (image_np * 255).astype(np.uint8)

                    if len(image_np.shape) == 3 and image_np.shape[-1] == 3:
                        result_image = Image.fromarray(image_np, 'RGB')
                    elif len(image_np.shape) == 3:
                        result_image = Image.fromarray(image_np[:, :, 0], 'L').convert('RGB')
                    else:
                        result_image = Image.fromarray(image_np, 'L').convert('RGB')
                else:
                    result_image = result_images

                # Force cleanup of intermediate tensors
                del loadimage_133, imagescaletototalpixels_187, vaeencode_124
                del cliptextencode_181, cliptextencode_182, referencelatent_176
                del fluxguidance_179, ksampler_178, vaedecode_177
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                return result_image, "✨ Avatar transformation complete!"
            finally:
                # Clean up temporary files
                try:
                    os.remove(temp_image_path)
                    os.rmdir(temp_dir)
                except OSError:
                    pass

    except Exception as e:
        # Force cleanup on error
        import gc

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return None, f"❌ Error: {str(e)}"


def create_gradio_interface():
    """Create the Gradio interface."""
    custom_css = """
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;700&display=swap');

    .gradio-container {
        background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%);
        min-height: 100vh;
        font-family: 'Inter', sans-serif;
        padding: 2rem;
    }

    .main-header {
        text-align: center;
        font-size: 3.5rem;
        font-weight: 700;
        background: linear-gradient(135deg, #06b6d4, #8b5cf6);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        margin-bottom: 1rem;
    }

    .main-description {
        text-align: center;
        font-size: 1.3rem;
        color: #cbd5e1;
        max-width: 800px;
        margin: 0 auto 3rem auto;
        line-height: 1.6;
    }

    .image-container {
        background: rgba(15, 23, 42, 0.6);
        border-radius: 20px;
        padding: 2rem;
        border: 2px solid rgba(6, 182, 212, 0.3);
        backdrop-filter: blur(10px);
    }

    .big-image-upload {
        border: 2px dashed rgba(6, 182, 212, 0.5) !important;
        border-radius: 16px !important;
        background: rgba(15, 23, 42, 0.3) !important;
        min-height: 500px !important;
        transition: all 0.3s ease !important;
    }

    .big-image-upload:hover {
        border-color: rgba(6, 182, 212, 0.8) !important;
        background: rgba(6, 182, 212, 0.05) !important;
    }

    .section-title {
        color: #06b6d4;
        font-weight: 600;
        font-size: 1.5rem;
        text-align: center;
        margin-bottom: 1.5rem;
    }

    .custom-textbox textarea {
        background: rgba(15, 23, 42, 0.6) !important;
        border: 2px solid rgba(6, 182, 212, 0.3) !important;
        border-radius: 12px !important;
        color: #e2e8f0 !important;
        font-size: 1.1rem !important;
        padding: 1rem !important;
        min-height: 120px !important;
    }

    .custom-textbox textarea:focus {
        border-color: rgba(6, 182, 212, 0.8) !important;
        background: rgba(15, 23, 42, 0.8) !important;
    }

    .transform-button {
        background: linear-gradient(135deg, #06b6d4, #3b82f6) !important;
        color: white !important;
        font-weight: 600 !important;
        font-size: 1.3rem !important;
        padding: 1.2rem 3rem !important;
        border-radius: 12px !important;
        border: none !important;
        box-shadow: 0 8px 25px rgba(6, 182, 212, 0.4) !important;
        transition: all 0.3s ease !important;
        width: 100% !important;
        margin-top: 1.5rem !important;
    }

    .transform-button:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 12px 35px rgba(6, 182, 212, 0.6) !important;
    }

    .status-display {
        background: rgba(15, 23, 42, 0.6);
        border-radius: 12px;
        padding: 1.5rem;
        margin-top: 1rem;
        border: 1px solid rgba(6, 182, 212, 0.2);
        min-height: 60px;
        display: flex;
        align-items: center;
        justify-content: center;
    }

    .status-success {
        color: #10b981 !important;
        font-weight: 600 !important;
        font-size: 1.2rem !important;
    }

    .status-error {
        color: #ef4444 !important;
        font-weight: 600 !important;
        font-size: 1.2rem !important;
    }

    .status-processing {
        color: #06b6d4 !important;
        font-weight: 600 !important;
        font-size: 1.2rem !important;
    }

    @media (max-width: 768px) {
        .main-header {
            font-size: 2.5rem;
        }
        .big-image-upload {
            min-height: 400px !important;
        }
        .gradio-container {
            padding: 1rem;
        }
    }
    """

    with gr.Blocks(css=custom_css, title="Avatar Transformation Studio") as interface:
        # Header
        gr.HTML("""
📷 Use clear, well-lit photos • 👤 Front-facing works best • 🎨 High resolution recommended • ✨ Be creative with prompts