import gradio as gr
import torch
import numpy as np
import boto3
import json
import soundfile as sf
import tempfile
from generator import load_csm_1b  # ✅ Only import the loader
import whisper  # or use faster-whisper if preferred

# ✅ Segment class embedded here to avoid the import error
class Segment:
    def __init__(self, text, speaker=0, audio=None):
        self.text = text
        self.speaker = speaker
        self.audio = audio
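
# Note (assumed usage, not exercised below): CSM's generate() accepts prior
# conversation turns via its `context` argument, e.g. a list of segments like
#     Segment(text="Hi there", speaker=0, audio=some_audio_tensor)
# This app currently generates with context=[].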

# Load models once at startup
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading Sesame CSM model on {device}...")
csm = load_csm_1b(device=device)
sample_rate = csm.sample_rate
print("✅ Sesame CSM model loaded")

print("Loading Whisper model...")
whisper_model = whisper.load_model("base")
print("✅ Whisper model loaded")

# LLaMA SageMaker endpoint
LLAMA_ENDPOINT = "Llama-3-2-3B-Instruct-streaming-endpoint"
REGION = "us-east-2"

def transcribe_audio(audio):
    """Transcribe a Gradio microphone tuple (sample_rate, numpy array) with Whisper."""
    sr, audio_np = audio
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, audio_np, sr)
    result = whisper_model.transcribe(f.name)
    return result["text"]

def generate_audio(text, speaker=0, max_audio_length_ms=10000):
    """Synthesize speech with Sesame CSM; returns a (sample_rate, numpy array) tuple for Gradio."""
    audio = csm.generate(
        text=text,
        speaker=int(speaker),  # slider values may arrive as floats
        context=[],
        max_audio_length_ms=max_audio_length_ms
    )
    return (sample_rate, audio.cpu().numpy())

def query_llama(prompt):
    """Send a prompt to the LLaMA SageMaker endpoint and return the generated text."""
    client = boto3.client("sagemaker-runtime", region_name=REGION)
    payload = {
        "inputs": f"[INST] {prompt.strip()} [/INST]"
    }
    response = client.invoke_endpoint(
        EndpointName=LLAMA_ENDPOINT,
        ContentType="application/json",
        Body=json.dumps(payload)
    )
    result = json.loads(response["Body"].read().decode("utf-8"))
    # The endpoint may return either [{"generated_text": ...}] or {"generated_text": ...}
    if isinstance(result, list):
        return result[0].get("generated_text", "")
    return result.get("generated_text", "")
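
# Hypothetical quick check of the endpoint wiring (kept commented out so the
# Gradio app below remains the only entry point):
#     print(query_llama("Say hello in one sentence."))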

# Unified function
def multi_mode(mode, text_input, mic_input, speaker, max_audio_length_ms):
    if mode == "Text to Text":
        response = query_llama(text_input)
        return response, None
    elif mode == "Text to Speech":
        return None, generate_audio(text_input, speaker, max_audio_length_ms)
    elif mode == "Speech to Text":
        transcribed = transcribe_audio(mic_input)
        return transcribed, None
    elif mode == "Speech to Speech":
        transcribed = transcribe_audio(mic_input)
        response = query_llama(transcribed)
        return None, generate_audio(response, speaker, max_audio_length_ms)
    else:
        return "Invalid mode selected.", None

# Gradio Interface
modes = ["Text to Text", "Text to Speech", "Speech to Text", "Speech to Speech"]

app = gr.Interface(
    fn=multi_mode,
    inputs=[
        gr.Radio(choices=modes, value="Text to Speech", label="Choose Mode"),
        gr.Textbox(lines=2, placeholder="Type something...", label="Text Input"),
        # Gradio 3.x argument; newer Gradio versions use sources=["microphone"]
        gr.Audio(source="microphone", type="numpy", label="Microphone Input"),
        gr.Slider(minimum=0, maximum=10, value=0, step=1, label="Speaker ID"),
        gr.Slider(minimum=1000, maximum=20000, value=10000, step=500, label="Max Audio Length (ms)")
    ],
    outputs=[
        gr.Textbox(label="Text Output"),
        gr.Audio(label="Speech Output")
    ],
    title="🎙️ Sesame AI: Speech + Text Assistant",
    description="Choose your mode to talk or type with Sesame. Powered by Whisper, LLaMA, and CSM."
)

app.launch()