leenag committed · verified
Commit c0bf138 · 1 Parent(s): e226644

Update app.py

Files changed (1):
  1. app.py +162 -99
app.py CHANGED
@@ -2,146 +2,209 @@ import gradio as gr
  from transformers import pipeline
  import numpy as np
  import os

  # --- Configuration ---
- # Choose your model:
- #   - "openai/whisper-tiny.en"  (fastest, English-only, lower accuracy)
- #   - "openai/whisper-base.en"  (good balance for English)
- #   - "openai/whisper-small.en"
- #   - "openai/whisper-tiny"     (multilingual, but tiny)
- #   - "openai/whisper-base"     (multilingual, good balance)
- # For Spaces, you might want a smaller model if you're on the free-tier CPU.
- MODEL_NAME = os.getenv("ASR_MODEL", "openai/whisper-base.en")  # Default to base.en
- DEVICE = "cuda" if os.getenv("USE_GPU", "false").lower() == "true" else "cpu"  # Check for GPU availability

  # --- Global Variables ---
- # Load the ASR pipeline (this will download the model on first run)
- try:
-     print(f"Loading ASR model: {MODEL_NAME} on device: {DEVICE}")
-     # For whisper, task="automatic-speech-recognition" is fine.
-     # If using a GPU on Spaces, make sure your Space has GPU hardware assigned.
-     # For CPU, this might be slow for larger models.
-     asr_pipeline = pipeline(
-         task="automatic-speech-recognition",
-         model=MODEL_NAME,
-         device=DEVICE if DEVICE == "cuda" else -1  # device=-1 for CPU with the Transformers pipeline
-     )
-     print("ASR model loaded successfully.")
- except Exception as e:
-     print(f"Error loading ASR model: {e}")
-     asr_pipeline = None  # Allow the app to run to show the error
-
- # --- Core Transcription Logic ---
- def transcribe_audio_chunk(new_chunk_audio, history_state):
-     """
-     Transcribes a new audio chunk and appends it to the history.
-     new_chunk_audio: tuple (sample_rate, numpy_array_of_audio_data) from gr.Audio
-     history_state: dictionary containing the accumulated transcription string.
-     """
-     if new_chunk_audio is None or asr_pipeline is None:
-         return history_state["full_text"], history_state  # No new audio or model not loaded
-
-     sample_rate, audio_data = new_chunk_audio
-
-     # Ensure audio_data is float32, as Whisper expects
-     audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max  # Normalize if it's int
-
-     # Check if audio_data is substantial enough (e.g., > 0.2 seconds).
-     # This helps avoid processing tiny, empty chunks.
-     if len(audio_data) < sample_rate * 0.2:
-         return history_state["full_text"], history_state
-
      try:
-         # Transcribe the audio data.
-         # The pipeline expects a dictionary with "sampling_rate" and "raw" audio data (numpy array),
-         # or it can take a filepath. For streaming, raw data is better.
-         transcription_result = asr_pipeline({"sampling_rate": sample_rate, "raw": audio_data})
-         new_text = transcription_result["text"].strip()
-
-         if new_text:
-             history_state["full_text"] += new_text + " "
-             print(f"New chunk: '{new_text}' | Full: '{history_state['full_text'][:100]}...'")  # Log for debugging
-     except Exception as e:
-         print(f"Error during transcription: {e}")
-         # Optionally, append an error message to the transcription:
-         # history_state["full_text"] += f"[Error: {e}] "
-         pass  # Continue even if one chunk fails
-
-     return history_state["full_text"], history_state

- # --- Gradio UI ---
- with gr.Blocks(title="Live Transcription") as demo:
      gr.Markdown(
          f"""
-         # 🎙️ Live Speech-to-Text with Hugging Face Whisper
-         Speak into your microphone. The transcription will appear below in real time.
          Using model: `{MODEL_NAME}` on device: `{DEVICE}`.
-         Note: There's a slight delay due to processing. For best results, speak clearly.
          """
      )

-     if asr_pipeline is None:
-         gr.Markdown("## ⚠️ Error: ASR Model Not Loaded. Check logs. ⚠️")

-     # State to store the full transcription history,
-     # initialized with a dictionary structure.
      transcription_history = gr.State({"full_text": ""})

      with gr.Row():
-         # Audio input: streaming from the microphone.
-         # 'type="numpy"' gives (sample_rate, data_array).
-         # 'streaming=True' enables continuous audio capture.
          audio_input = gr.Audio(
              sources=["microphone"],
              type="numpy",
              streaming=True,
-             label="Speak Here (Streaming Active)",
-             # waveform_options=gr.WaveformOptions(show_controls=True)  # Optional: show audio controls
          )
-
-         # Text output for the live transcription
          transcription_output = gr.Textbox(
-             label="Live Transcription",
-             lines=15,
-             interactive=False,  # The user shouldn't edit this directly
-             show_copy_button=True
          )

-     # Connect the audio input's streaming data to the transcription function.
-     # The 'stream' event is triggered periodically with new audio data.
-     # 'every=1' means the function is called every second with the audio collected during that second.
-     # You can adjust 'every' (e.g., 0.5 for faster updates but more processing, 2 for less frequent).
      audio_input.stream(
-         fn=transcribe_audio_chunk,
          inputs=[audio_input, transcription_history],
          outputs=[transcription_output, transcription_history],
-         every=1  # Process audio chunks every second
      )

-     # Button to clear the transcription
-     def clear_transcription(current_state):
          current_state["full_text"] = ""
-         print("Transcription cleared.")
-         return "", current_state  # Clear the textbox and update the state

-     clear_button = gr.Button("Clear Transcription")
      clear_button.click(
-         fn=clear_transcription,
          inputs=[transcription_history],
          outputs=[transcription_output, transcription_history]
      )

-     gr.Markdown(
-         """
-         ---
-         Built with [Gradio](https://gradio.app) and [Hugging Face Transformers](https://huggingface.co/transformers).
-         Model: [OpenAI Whisper](https://huggingface.co/models?search=openai/whisper)
-         """
-     )
-
- # To run locally (optional, usually not needed for HF Spaces if app.py is the entry point)
  if __name__ == "__main__":
-     # You can set environment variables here for local testing:
      # os.environ["ASR_MODEL"] = "openai/whisper-tiny.en"
      # os.environ["USE_GPU"] = "False"
-     demo.queue().launch(debug=True, share=False)  # Use queue for better handling of concurrent users
 
  from transformers import pipeline
  import numpy as np
  import os
+ import torch
+ import torchaudio  # For VAD

  # --- Configuration ---
+ MODEL_NAME = os.getenv("ASR_MODEL", "openai/whisper-base.en")
+ DEVICE = "cuda" if torch.cuda.is_available() and os.getenv("USE_GPU", "false").lower() == "true" else "cpu"
+ print(f"Using device: {DEVICE}")

  # --- Global Variables ---
+ asr_pipeline = None
+ vad_model = None
+ vad_utils = None
+ audio_buffer = []  # To accumulate audio chunks
+ MAX_BUFFER_SECONDS = 10  # Max audio to buffer before forcing transcription
+ SILENCE_THRESHOLD_SECONDS = 1.5  # How long silence lasts before processing a speech segment
+
+ # --- Load Models ---
+ def load_models():
+     global asr_pipeline, vad_model, vad_utils
+     try:
+         print(f"Loading ASR model: {MODEL_NAME} on device: {DEVICE}")
+         asr_pipeline = pipeline(
+             task="automatic-speech-recognition",
+             model=MODEL_NAME,
+             device=DEVICE if DEVICE == "cuda" else -1
+         )
+         print("ASR model loaded successfully.")
+
+         print("Loading Silero VAD model...")
+         # The Silero VAD model itself is small and runs efficiently on CPU
+         vad_model, vad_utils_tuple = torch.hub.load(repo_or_dir='snakers4/silero-vad',
+                                                     model='silero_vad',
+                                                     force_reload=False,  # Set to True if you have issues
+                                                     onnx=True)  # Use ONNX for better CPU performance
+         (get_speech_timestamps,
+          save_audio,
+          read_audio,
+          VADIterator,
+          collect_chunks) = vad_utils_tuple
+         vad_utils = {
+             "get_speech_timestamps": get_speech_timestamps,
+             "VADIterator": VADIterator
+         }
+         print("Silero VAD model loaded successfully.")

+     except Exception as e:
+         print(f"Error loading models: {e}")
+         if asr_pipeline is None: print("ASR pipeline failed to load.")
+         if vad_model is None: print("VAD model failed to load.")
+
+ load_models()  # Load models at startup

+ # --- Core Transcription Logic with VAD ---
+ def transcribe_with_vad(new_chunk_audio, history_state):
+     global audio_buffer

+     if new_chunk_audio is None or asr_pipeline is None or vad_model is None:
+         return history_state.get("full_text", ""), history_state

+     sample_rate, audio_data = new_chunk_audio
+     audio_data_float32 = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
+
+     # Append to buffer
+     audio_buffer.append(audio_data_float32)
+
+     # Check buffer length; if too short, wait for more audio
+     current_buffer_duration = sum(len(chunk) / sample_rate for chunk in audio_buffer)
+
+     # If buffer is empty or too short, just return current state
+     if not audio_buffer or current_buffer_duration < 0.2:  # Minimum duration to process
+         return history_state.get("full_text", ""), history_state
+
+     # Concatenate buffer for VAD processing
+     full_audio_np = np.concatenate(audio_buffer)
+     full_audio_tensor = torch.from_numpy(full_audio_np).float()
+
+     # Use VAD to find speech timestamps.
+     # We're looking for the *end* of speech segments.
+     # This is a simplified approach: we process if VAD detects no speech in the latest part
+     # or if the buffer gets too long.
      try:
+         # For simplicity, analyze the last N seconds for silence.
+         # A more robust VADIterator approach would be better for continuous streaming
+         # but is more complex to manage with Gradio's chunking.
+
+         # Simpler VAD check: see whether the buffered audio contains speech.
+         # For a more robust solution, use VADIterator or process the whole buffer.
+         speech_timestamps = vad_utils["get_speech_timestamps"](
+             full_audio_tensor,
+             vad_model,
+             sampling_rate=sample_rate,
+             min_silence_duration_ms=500  # ms of silence to consider a break
+         )

+         # Heuristic: process if speech_timestamps is empty for the latest chunk,
+         # OR if the buffer is long, OR if there's a significant pause
+         process_now = False
+         transcribed_text_segment = ""
+
+         if not speech_timestamps:  # No speech detected in the current combined buffer
+             if current_buffer_duration > SILENCE_THRESHOLD_SECONDS:  # and we have enough audio to assume it's silence after speech
+                 process_now = True
+         elif current_buffer_duration > MAX_BUFFER_SECONDS:  # Buffer is too long, process it
+             process_now = True
+         else:
+             # If speech is detected, check if the end of the last speech segment is
+             # significantly before the end of the buffer. This indicates a pause after speech.
+             if speech_timestamps:
+                 last_speech_end_s = speech_timestamps[-1]['end'] / sample_rate
+                 if current_buffer_duration - last_speech_end_s > SILENCE_THRESHOLD_SECONDS:
+                     process_now = True
+
+         if process_now and full_audio_np.any():  # Ensure there's actual audio data
+             print(f"Processing {current_buffer_duration:.2f}s of buffered audio.")
+             # Transcribe the entire current buffer
+             transcription_result = asr_pipeline(
+                 {"sampling_rate": sample_rate, "raw": full_audio_np.copy()},  # Send a copy
+                 # You can add Whisper-specific args here if needed, e.g. chunk_length_s for long-form audio
+                 # generate_kwargs={"task": "transcribe", "language": "<|en|>"}  # for multilingual models
+             )
+             new_text = transcription_result["text"].strip()
+
+             if new_text:
+                 transcribed_text_segment = new_text + " "
+                 history_state["full_text"] = history_state.get("full_text", "") + transcribed_text_segment
+                 print(f"VAD processed: '{new_text}'")
+
+             audio_buffer = []  # Clear buffer after processing

+     except Exception as e:
+         print(f"Error during VAD/transcription: {e}")
+         # Fallback: transcribe the accumulated buffer on error, then clear it
+         if audio_buffer:
+             try:
+                 full_audio_fallback = np.concatenate(audio_buffer)
+                 if full_audio_fallback.any():
+                     transcription_result = asr_pipeline(
+                         {"sampling_rate": sample_rate, "raw": full_audio_fallback.copy()}
+                     )
+                     new_text = transcription_result["text"].strip()
+                     if new_text:
+                         history_state["full_text"] = history_state.get("full_text", "") + new_text + " "
+                         print(f"Fallback processed: '{new_text}'")
+             except Exception as fallback_e:
+                 print(f"Error during fallback transcription: {fallback_e}")
+         audio_buffer = []  # Clear buffer
+
+     return history_state.get("full_text", ""), history_state
+
+ # --- Gradio UI (largely the same, just point to the new function and manage state) ---
+ with gr.Blocks(title="Live Transcription with VAD") as demo:
      gr.Markdown(
          f"""
+         # 🎙️ Live Speech-to-Text with VAD & Hugging Face Whisper
+         Speak into your microphone. Transcription will appear after speech segments.
          Using model: `{MODEL_NAME}` on device: `{DEVICE}`.
+         VAD: Silero VAD
          """
      )

+     if asr_pipeline is None or vad_model is None:
+         gr.Markdown("## ⚠️ Error: Models Not Loaded. Check logs. ⚠️")

      transcription_history = gr.State({"full_text": ""})

      with gr.Row():
          audio_input = gr.Audio(
              sources=["microphone"],
              type="numpy",
              streaming=True,
+             label="Speak Here (Streaming Active with VAD)",
          )
          transcription_output = gr.Textbox(
+             label="Live Transcription", lines=15, interactive=False, show_copy_button=True
          )

+     # Adjust 'every' based on how frequently you want to check the VAD buffer.
+     # A smaller 'every' means more frequent checks and potentially more responsive VAD,
+     # but also more frequent function calls.
      audio_input.stream(
+         fn=transcribe_with_vad,
          inputs=[audio_input, transcription_history],
          outputs=[transcription_output, transcription_history],
+         every=0.5  # Check the buffer and VAD every 0.5 seconds
      )

+     def clear_transcription_state(current_state):
+         global audio_buffer
+         audio_buffer = []  # Also clear the audio buffer
          current_state["full_text"] = ""
+         print("Transcription and audio buffer cleared.")
+         return "", current_state

+     clear_button = gr.Button("Clear Transcription & Buffer")
      clear_button.click(
+         fn=clear_transcription_state,
          inputs=[transcription_history],
          outputs=[transcription_output, transcription_history]
      )
+     gr.Markdown("---")

  if __name__ == "__main__":
      # os.environ["ASR_MODEL"] = "openai/whisper-tiny.en"
      # os.environ["USE_GPU"] = "False"
+     # load_models()  # Ensure models are loaded if running locally
+     demo.queue().launch(debug=True, share=False)
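
The inline comments in `transcribe_with_vad` note that a `VADIterator`-based approach would be more robust for continuous streaming than re-running `get_speech_timestamps` over the whole buffer, but harder to manage with Gradio's chunking. For reference, here is a minimal sketch of that pattern using the same Silero VAD utilities that `load_models()` unpacks; the 16 kHz sample rate, the 512-sample window, and the synthetic waveform are illustrative assumptions, not part of this commit.

```python
# Sketch only (not part of app.py): driving Silero's VADIterator over fixed-size
# windows, the streaming alternative referenced in the comments above.
import numpy as np
import torch

SAMPLE_RATE = 16000   # assumed rate; Silero VAD expects 8 or 16 kHz
WINDOW = 512          # samples per VAD window at 16 kHz (assumption)

vad_model, vad_utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=True
)
get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks = vad_utils

# min_silence_duration_ms mirrors the 500 ms break used in transcribe_with_vad
vad_iterator = VADIterator(vad_model, sampling_rate=SAMPLE_RATE, min_silence_duration_ms=500)

# Stand-in for buffered microphone audio: 2 s of low-level noise, then 1 s of silence
wav = np.concatenate(
    [0.05 * np.random.randn(2 * SAMPLE_RATE), np.zeros(SAMPLE_RATE)]
).astype(np.float32)

# The iterator emits {'start': ...} / {'end': ...} events as it crosses speech
# boundaries; an 'end' event is the natural trigger for flushing the buffer to Whisper.
for i in range(0, len(wav) - WINDOW + 1, WINDOW):
    event = vad_iterator(torch.from_numpy(wav[i:i + WINDOW]), return_seconds=True)
    if event:
        print(event)

vad_iterator.reset_states()  # reset between independent audio streams
```

Unlike the buffer-and-rescan heuristic in the commit, this keeps VAD state across chunks, so the buffer would only need to be flushed when an `'end'` event fires.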