# hswift's picture
# Update app.py
# 679b885 verified
import os
import logging
from starlette.background import BackgroundTask # <-- IMPORT THE FIX
# --- Writable cache locations ---
# These variables have to be in place before torch / diffusers are imported,
# so that the libraries reading them use directories under /tmp rather than
# a location the process may not be allowed to write to.
CACHE_DIR = "/tmp/huggingface_cache"
_hub_cache_dir = os.path.join(CACHE_DIR, "hub")
_torch_cache_dir = os.path.join(CACHE_DIR, "torch")

os.environ.update(
    {
        "HF_HOME": CACHE_DIR,
        "HF_HUB_CACHE": _hub_cache_dir,
        "TORCH_HOME": _torch_cache_dir,
    }
)

# Create both cache directories up front; existing ones are left untouched.
for _cache_path in (_hub_cache_dir, _torch_cache_dir):
    os.makedirs(_cache_path, exist_ok=True)
# Now it's safe to import the other libraries
import torch
import tempfile
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from diffusers import AudioLDMPipeline
from scipy.io.wavfile import write as write_wav
import numpy as np
# Logging: a logger named "app" for this module, with basicConfig requesting
# INFO level for the root logger.
logger = logging.getLogger("app")
logging.basicConfig(level=logging.INFO)
# --- Application object and CORS policy ---
app = FastAPI()

# Wide-open CORS (any origin, method and header), which is handy during
# development.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# --- Request schema ---
# JSON body accepted by POST /generate-audio.
# Documented with comments rather than a class docstring so that the schema
# generated from this model stays exactly as it was.
class AudioRequest(BaseModel):
    # Text prompt that is handed to the audio pipeline.
    prompt: str
# --- Model loading (executed once, at import time) ---
# Half precision on CUDA when a GPU is available, full precision on CPU otherwise.
if torch.cuda.is_available():
    device, torch_dtype = "cuda", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

logger.info(f"Using device: {device} with dtype: {torch_dtype}")
logger.info(f"Using model cache directory: {CACHE_DIR}")

# `pipe` is only rebound after the pipeline has been both loaded and moved to
# the target device; on any failure it keeps this None value, which the
# request handler checks before generating audio.
pipe = None
try:
    # Repository id handed to from_pretrained.
    repo_id = "cvssp/audioldm-s-full-v2"
    # cache_dir is passed explicitly in addition to the cache environment
    # variables configured at the top of the file.
    pipe = AudioLDMPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch_dtype,
        cache_dir=CACHE_DIR,
    ).to(device)
    logger.info(f"Successfully loaded model: {repo_id}")
except Exception as e:
    # Logged at ERROR level together with the traceback; the app keeps running
    # with pipe left as None.
    logger.exception(f"Failed to load the model: {e}")
# --- API Endpoint ---
@app.post("/generate-audio")
async def generate_audio_endpoint(request: AudioRequest):
    """Generate a short audio clip for the prompt and return it as a WAV download.

    Args:
        request: Request body holding the text ``prompt``.

    Returns:
        A ``FileResponse`` serving a temporary 16-bit PCM WAV file; the file is
        removed by a background task after the response has been sent.

    Raises:
        HTTPException: 503 when the model is not loaded; 500 when generation
            or writing the file fails (``detail`` carries the error text).
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model is not available. Check server logs for loading errors.")

    prompt = request.prompt
    logger.info(f"Generating audio for prompt: '{prompt}'")
    temp_file_path = ""
    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # other requests wait while it runs. Moving it to a worker thread would
        # also need a lock around the shared pipeline; left unchanged here.
        audio = pipe(
            prompt,
            num_inference_steps=200,
            audio_length_in_s=5.0,
            guidance_scale=7.0,
        ).audios[0]

        # The WAV header must carry the rate the model actually produced
        # (16 kHz for AudioLDM); a hard-coded 44100 makes the clip play back
        # too fast and pitch-shifted.
        sample_rate = int(pipe.vocoder.config.sampling_rate)

        # Reserve a unique path, then close the handle before scipy reopens
        # the file by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file_path = temp_file.name

        # Clip to [-1, 1] first so out-of-range samples saturate instead of
        # wrapping around when cast to int16.
        audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        write_wav(temp_file_path, sample_rate, audio_int16)
        logger.info(f"Audio saved to temporary file: {temp_file_path}")

        # Delete the temporary file only after the response has been sent.
        cleanup_task = BackgroundTask(os.remove, temp_file_path)
        return FileResponse(
            path=temp_file_path,
            media_type='audio/wav',
            filename=f"{prompt[:50].replace(' ', '_')}.wav",
            background=cleanup_task,
        )
    except Exception as e:
        logger.error(f"Error during audio generation for prompt '{prompt}': {e}", exc_info=True)
        # No response will be sent, so the background cleanup never runs:
        # remove the temporary file here if it was created.
        if temp_file_path:
            try:
                os.remove(temp_file_path)
            except FileNotFoundError:
                pass
        raise HTTPException(status_code=500, detail=str(e)) from e
# Root route: returns a fixed status payload showing the API is up.
# (A comment is used instead of a docstring so the published OpenAPI
# description of this route is unchanged.)
@app.get("/")
def read_root():
    status_payload = {"status": "Audio generation API is running."}
    return status_payload