import gradio as gr
import torch
import numpy as np
import boto3
import json
import soundfile as sf
import tempfile
from generator import load_csm_1b  # ✅ Only import the loader
import whisper  # or use faster-whisper if preferred


# ✅ Embedded Segment class here to fix import error
class Segment:
    def __init__(self, text, speaker=0, audio=None):
        self.text = text
        self.speaker = speaker
        self.audio = audio


# Load models once
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading Sesame CSM model on {device}...")
csm = load_csm_1b(device=device)
sample_rate = csm.sample_rate
print("✅ Sesame CSM model loaded")

print("Loading Whisper model...")
whisper_model = whisper.load_model("base")
print("✅ Whisper model loaded")

# LLaMA SageMaker endpoint
LLAMA_ENDPOINT = "Llama-3-2-3B-Instruct-streaming-endpoint"
REGION = "us-east-2"


def transcribe_audio(audio):
    """Transcribe a Gradio (sample_rate, numpy_array) tuple with Whisper."""
    sr, audio_np = audio
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, audio_np, sr)
    result = whisper_model.transcribe(f.name)
    return result["text"]


def generate_audio(text, speaker=0, max_audio_length_ms=10000):
    """Synthesize speech for `text` with Sesame CSM and return (sample_rate, waveform)."""
    audio = csm.generate(
        text=text,
        speaker=int(speaker),  # CSM expects an integer speaker ID
        context=[],
        max_audio_length_ms=max_audio_length_ms,
    )
    return (sample_rate, audio.cpu().numpy())


def query_llama(prompt):
    """Send the prompt to the LLaMA SageMaker endpoint and return the generated text."""
    client = boto3.client("sagemaker-runtime", region_name=REGION)
    payload = {"inputs": f"[INST] {prompt.strip()} [/INST]"}
    response = client.invoke_endpoint(
        EndpointName=LLAMA_ENDPOINT,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    result = json.loads(response["Body"].read().decode("utf-8"))
    if isinstance(result, list):
        return result[0].get("generated_text", "")
    return result.get("generated_text", "")


# Unified function dispatching on the selected mode
def multi_mode(mode, text_input, mic_input, speaker, max_audio_length_ms):
    if mode == "Text to Text":
        response = query_llama(text_input)
        return response, None
    elif mode == "Text to Speech":
        return None, generate_audio(text_input, speaker, max_audio_length_ms)
    elif mode == "Speech to Text":
        transcribed = transcribe_audio(mic_input)
        return transcribed, None
    elif mode == "Speech to Speech":
        transcribed = transcribe_audio(mic_input)
        response = query_llama(transcribed)
        return None, generate_audio(response, speaker, max_audio_length_ms)
    else:
        return "Invalid mode selected.", None


# Gradio Interface
modes = ["Text to Text", "Text to Speech", "Speech to Text", "Speech to Speech"]

app = gr.Interface(
    fn=multi_mode,
    inputs=[
        gr.Radio(choices=modes, value="Text to Speech", label="Choose Mode"),
        gr.Textbox(lines=2, placeholder="Type something...", label="Text Input"),
        gr.Audio(source="microphone", type="numpy", label="Microphone Input"),
        gr.Slider(minimum=0, maximum=10, value=0, step=1, label="Speaker ID"),
        gr.Slider(minimum=1000, maximum=20000, value=10000, step=500, label="Max Audio Length (ms)"),
    ],
    outputs=[
        gr.Textbox(label="Text Output"),
        gr.Audio(label="Speech Output"),
    ],
    title="🎙️ Sesame AI: Speech + Text Assistant",
    description="Choose your mode to talk or type with Sesame. Powered by Whisper, LLaMA, and CSM.",
)

app.launch()
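
# -----------------------------------------------------------------------------
# Notes (assumptions, not verified against pinned package versions):
# * Expected dependencies: gradio, torch, numpy, boto3, soundfile,
#   openai-whisper, and the Sesame CSM `generator` module on PYTHONPATH.
# * `gr.Audio(source="microphone")` matches the Gradio 3.x API; on Gradio 4.x
#   the argument was renamed to `sources=["microphone"]`.
# * The `Segment` class defined above is currently unused because `context=[]`
#   is passed to `csm.generate`; populating it with prior (text, speaker,
#   audio) turns would be one way to give CSM conversational context.
# * The `[INST] ... [/INST]` wrapper in `query_llama` is the Llama 2 chat
#   format; a Llama 3.2 endpoint may expect a different chat template
#   depending on how the endpoint's container interprets the raw `inputs`
#   string.
# -----------------------------------------------------------------------------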