import gradio as gr
import torch
import numpy as np
import boto3
import json
import os
import soundfile as sf
import tempfile
from generator import load_csm_1b  # ✅ Only import the loader
import whisper  # or use faster-whisper if preferred
# ✅ Embedded Segment class here to fix the import error.
# Stands in for generator.Segment from the CSM repo; `audio` is expected
# to be a torch.Tensor sampled at the model's sample rate.
class Segment:
    def __init__(self, text, speaker=0, audio=None):
        self.text = text
        self.speaker = speaker
        self.audio = audio
# Load models once at startup
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading Sesame CSM model on {device}...")
csm = load_csm_1b(device=device)
sample_rate = csm.sample_rate
print("✅ Sesame CSM model loaded")

print("Loading Whisper model...")
whisper_model = whisper.load_model("base")
print("✅ Whisper model loaded")
# LLaMA SageMaker endpoint
LLAMA_ENDPOINT = "Llama-3-2-3B-Instruct-streaming-endpoint"
REGION = "us-east-2"
def transcribe_audio(audio):
    """Transcribe Gradio mic input (sample_rate, np.ndarray) with Whisper."""
    sr, audio_np = audio
    # Round-trip through a WAV file so ffmpeg handles resampling for Whisper.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, audio_np, sr)
        path = f.name
    try:
        result = whisper_model.transcribe(path)
    finally:
        os.remove(path)  # don't leak temp files across requests
    return result["text"]
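
# A minimal, untested alternative that skips the temp file by feeding Whisper
# a float32 array directly (whisper's transcribe also accepts np.ndarray).
# Whisper expects 16 kHz mono; the crude np.interp resample below is an
# assumption and lower quality than the ffmpeg path above.
def transcribe_audio_in_memory(audio):
    sr, audio_np = audio
    mono = audio_np.mean(axis=1) if audio_np.ndim > 1 else audio_np
    mono = mono.astype(np.float32) / 32768.0  # int16 -> float32 in [-1, 1]
    if sr != 16000:
        target_len = int(len(mono) * 16000 / sr)
        mono = np.interp(
            np.linspace(0, len(mono), num=target_len, endpoint=False),
            np.arange(len(mono)),
            mono,
        ).astype(np.float32)
    return whisper_model.transcribe(mono)["text"]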
def generate_audio(text, speaker=0, max_audio_length_ms=10000):
    """Synthesize speech with CSM and return (sample_rate, np.ndarray) for Gradio."""
    audio = csm.generate(
        text=text,
        speaker=speaker,
        context=[],
        max_audio_length_ms=max_audio_length_ms,
    )
    return (sample_rate, audio.cpu().numpy())
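
# A hedged sketch of conditioning CSM on earlier turns for a more consistent
# voice, which is what the Segment class above exists for. The context format
# follows the upstream CSM examples; `prev_audio` would be a torch.Tensor of a
# previous reply at `sample_rate`. Untested here, since this app always passes
# context=[].
def generate_audio_with_context(text, context_segments, speaker=0, max_audio_length_ms=10000):
    audio = csm.generate(
        text=text,
        speaker=speaker,
        context=context_segments,  # e.g. [Segment(text="Hi there", speaker=0, audio=prev_audio)]
        max_audio_length_ms=max_audio_length_ms,
    )
    return (sample_rate, audio.cpu().numpy())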
def query_llama(prompt):
    """Call the LLaMA SageMaker endpoint and return the generated text."""
    client = boto3.client("sagemaker-runtime", region_name=REGION)
    # Note: [INST]...[/INST] is the Llama-2 prompt format; Llama-3 endpoints
    # often expect the <|start_header_id|> chat template instead, so verify
    # what your container actually accepts.
    payload = {
        "inputs": f"[INST] {prompt.strip()} [/INST]"
    }
    response = client.invoke_endpoint(
        EndpointName=LLAMA_ENDPOINT,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    result = json.loads(response["Body"].read().decode("utf-8"))
    # Hugging Face containers return either a list of dicts or a single dict.
    if isinstance(result, list):
        return result[0].get("generated_text", "")
    return result.get("generated_text", "")
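
# Some SageMaker text-generation containers echo the prompt back inside
# "generated_text". A minimal sketch for stripping that echo, keyed on the
# [/INST] marker used above; whether your endpoint echoes the prompt is an
# assumption to verify against its actual responses.
def strip_prompt_echo(generated_text):
    marker = "[/INST]"
    if marker in generated_text:
        return generated_text.split(marker, 1)[1].strip()
    return generated_text.strip()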
# Unified function: routes each UI mode to the right model pipeline.
def multi_mode(mode, text_input, mic_input, speaker, max_audio_length_ms):
    if mode == "Text to Text":
        response = query_llama(text_input)
        return response, None
    elif mode == "Text to Speech":
        return None, generate_audio(text_input, speaker, max_audio_length_ms)
    elif mode == "Speech to Text":
        if mic_input is None:
            return "Please record some audio first.", None
        transcribed = transcribe_audio(mic_input)
        return transcribed, None
    elif mode == "Speech to Speech":
        if mic_input is None:
            return "Please record some audio first.", None
        transcribed = transcribe_audio(mic_input)
        response = query_llama(transcribed)
        return None, generate_audio(response, speaker, max_audio_length_ms)
    else:
        return "Invalid mode selected.", None
# Gradio Interface
modes = ["Text to Text", "Text to Speech", "Speech to Text", "Speech to Speech"]

app = gr.Interface(
    fn=multi_mode,
    inputs=[
        gr.Radio(choices=modes, value="Text to Speech", label="Choose Mode"),
        gr.Textbox(lines=2, placeholder="Type something...", label="Text Input"),
        # Gradio 4.x renamed `source` to `sources`; on Gradio 3.x use
        # source="microphone" instead.
        gr.Audio(sources=["microphone"], type="numpy", label="Microphone Input"),
        gr.Slider(minimum=0, maximum=10, value=0, step=1, label="Speaker ID"),
        gr.Slider(minimum=1000, maximum=20000, value=10000, step=500, label="Max Audio Length (ms)")
    ],
    outputs=[
        gr.Textbox(label="Text Output"),
        gr.Audio(label="Speech Output")
    ],
    title="🎙️ Sesame AI: Speech + Text Assistant",
    description="Choose your mode to talk or type with Sesame. Powered by Whisper, LLaMA, and CSM."
)

app.launch()
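
# Rough dependency list for this Space, inferred from the imports above
# (a sketch; pin versions as needed):
#   gradio, torch, numpy, boto3, soundfile, openai-whisper
# plus the CSM repo providing generator.load_csm_1b. boto3 also needs AWS
# credentials (e.g. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY as Space
# secrets) to reach the SageMaker endpoint.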