import os
import logging
from starlette.background import BackgroundTask  # <-- IMPORT THE FIX

# --- FIX FOR ALL PERMISSION ERRORS ---
# Set environment variables BEFORE importing torch or diffusers.
# This forces all underlying libraries (huggingface_hub, torch, etc.)
# to use a writable directory inside /tmp, avoiding any permission errors.
CACHE_DIR = "/tmp/huggingface_cache"
os.environ['HF_HOME'] = CACHE_DIR
os.environ['HF_HUB_CACHE'] = os.path.join(CACHE_DIR, 'hub')
os.environ['TORCH_HOME'] = os.path.join(CACHE_DIR, 'torch')
os.makedirs(os.path.join(CACHE_DIR, 'hub'), exist_ok=True)
os.makedirs(os.path.join(CACHE_DIR, 'torch'), exist_ok=True)

# Now it's safe to import the other libraries
import torch
import tempfile
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from diffusers import AudioLDMPipeline
from scipy.io.wavfile import write as write_wav
import numpy as np

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

# --- App Setup ---
app = FastAPI()

# Allow all origins for CORS, useful for development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Pydantic Model for Request Body ---
class AudioRequest(BaseModel):
    prompt: str

# --- Model Loading ---
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
logger.info(f"Using device: {device} with dtype: {torch_dtype}")
logger.info(f"Using model cache directory: {CACHE_DIR}")

pipe = None
try:
    # Use the stable, recommended model
    repo_id = "cvssp/audioldm-s-full-v2"
    pipe = AudioLDMPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch_dtype,
        # cache_dir is still good practice, but the environment variables are the real fix
        cache_dir=CACHE_DIR,
    )
    pipe = pipe.to(device)
    logger.info(f"Successfully loaded model: {repo_id}")
except Exception as e:
    logger.error(f"Failed to load the model: {e}", exc_info=True)
    pipe = None  # Ensure pipe is None if loading fails

# --- API Endpoint ---
@app.post("/generate")  # route path is assumed for illustration
async def generate_audio_endpoint(request: AudioRequest):
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not available. Check server logs for loading errors.")

    prompt = request.prompt
    logger.info(f"Generating audio for prompt: '{prompt}'")

    temp_file_path = ""
    try:
        audio = pipe(
            prompt,
            num_inference_steps=200,
            audio_length_in_s=5.0,
            guidance_scale=7.0,
        ).audios[0]

        # AudioLDM's vocoder produces 16 kHz audio; read the rate from the pipeline
        # instead of hard-coding 44100, which would skew playback speed and pitch.
        sample_rate = pipe.vocoder.config.sampling_rate
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file_path = temp_file.name
            # Convert float audio in [-1, 1] to 16-bit PCM, clipping to guard
            # against occasional out-of-range samples before scaling.
            audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
            write_wav(temp_file_path, sample_rate, audio_int16)

        logger.info(f"Audio saved to temporary file: {temp_file_path}")

        # ### THIS IS THE FIX ###
        # Create a background task to delete the file AFTER the response is sent.
        cleanup_task = BackgroundTask(os.remove, temp_file_path)

        return FileResponse(
            path=temp_file_path,
            media_type='audio/wav',
            filename=f"{prompt[:50].replace(' ', '_')}.wav",
            background=cleanup_task,  # Use the background task here
        )
    except Exception as e:
        logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
        if temp_file_path and os.path.exists(temp_file_path):
            os.remove(temp_file_path)  # Clean up if something else goes wrong
        raise HTTPException(status_code=500, detail=str(e))

# --- Health Check ---
@app.get("/")
def read_root():
    return {"status": "Audio generation API is running."}
