Update streamlit_app.py
streamlit_app.py  CHANGED  (+125 -53)
@@ -339,19 +339,59 @@ def transcribe_audio(audio_data, whisper_model, whisper_processor):
                 tmp_file.write(audio_data)
                 tmp_file_path = tmp_file.name
 
-            # Load audio using librosa
+            # Load audio using multiple fallback methods
+            audio_array = None
+            sampling_rate = 16000
+
+            # Method 1: Try librosa
             try:
                 audio_array, sampling_rate = librosa.load(tmp_file_path, sr=16000, dtype=np.float32)
+                st.info("✅ Audio loaded with librosa")
             except Exception as e:
+                st.warning(f"Librosa failed: {e}")
+
+                # Method 2: Try soundfile
+                try:
+                    audio_array, sampling_rate = sf.read(tmp_file_path)
+                    if sampling_rate != 16000:
+                        audio_array = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=16000)
+                        sampling_rate = 16000
+                    # Ensure float32 dtype
+                    audio_array = audio_array.astype(np.float32)
+                    st.info("✅ Audio loaded with soundfile")
+                except Exception as e2:
+                    st.warning(f"Soundfile failed: {e2}")
+
+                    # Method 3: Try scipy.io.wavfile
+                    try:
+                        sampling_rate, audio_array = wavfile.read(tmp_file_path)
+                        # Convert to float32 and normalize
+                        if audio_array.dtype == np.int16:
+                            audio_array = audio_array.astype(np.float32) / 32768.0
+                        elif audio_array.dtype == np.int32:
+                            audio_array = audio_array.astype(np.float32) / 2147483648.0
+                        else:
+                            audio_array = audio_array.astype(np.float32)
+
+                        # Resample if needed
+                        if sampling_rate != 16000:
+                            audio_array = librosa.resample(audio_array, orig_sr=sampling_rate, target_sr=16000)
+                            sampling_rate = 16000
+                        st.info("✅ Audio loaded with scipy.wavfile")
+                    except Exception as e3:
+                        st.error(f"All audio loading methods failed: {e3}")
+                        # Clean up and return empty
+                        try:
+                            os.unlink(tmp_file_path)
+                        except:
+                            pass
+                        return ""
 
             # Clean up temporary file
-            os.unlink(tmp_file_path)
+            try:
+                os.unlink(tmp_file_path)
+            except:
+                pass
         else:
             audio_array = audio_data
             sampling_rate = 16000
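The hunk above replaces a single librosa.load call with a three-backend fallback chain (librosa, then soundfile, then scipy.io.wavfile). Below is a minimal standalone sketch of the same pattern; the helper name load_audio_16k is hypothetical, while the backends, resampling, and integer-PCM scaling factors are taken from the diff.

import numpy as np
import librosa
import soundfile as sf
from scipy.io import wavfile

def load_audio_16k(path: str) -> np.ndarray:
    """Return mono float32 audio at 16 kHz, trying librosa, then soundfile, then scipy."""
    try:
        # librosa resamples and converts to float32 in one call
        audio, sr = librosa.load(path, sr=16000, dtype=np.float32)
        return audio
    except Exception:
        pass
    try:
        audio, sr = sf.read(path)
    except Exception:
        sr, audio = wavfile.read(path)
        # Integer PCM is scaled by its full range so values land in [-1, 1]
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype == np.int32:
            audio = audio.astype(np.float32) / 2147483648.0
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim > 1:
        audio = audio[:, 0]  # keep the first channel
    if sr != 16000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
    return audio

Dividing int16 samples by 32768.0 (and int32 by 2147483648.0) maps the full integer range into [-1.0, 1.0), matching what librosa and soundfile already return for float output.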
@@ -359,31 +399,46 @@ def transcribe_audio(audio_data, whisper_model, whisper_processor):
         if hasattr(audio_array, 'astype'):
             audio_array = audio_array.astype(np.float32)
 
+        # Validate audio array
+        if audio_array is None or len(audio_array) == 0:
+            st.error("❌ Audio array is empty or invalid")
+            return ""
+
         # Ensure audio is normalized and in correct format
         if isinstance(audio_array, np.ndarray):
+            # Handle multi-channel audio by taking the first channel
+            if len(audio_array.shape) > 1:
+                audio_array = audio_array[:, 0]  # Take first channel if stereo
+
+            # Check minimum audio length (at least 0.5 seconds)
+            min_length = int(0.5 * 16000)  # 0.5 seconds at 16 kHz
+            if len(audio_array) < min_length:
+                st.warning("⚠️ Audio is too short (less than 0.5 seconds). Please record a longer audio.")
+                return ""
+
             # Normalize audio to [-1, 1] range if needed
             if np.max(np.abs(audio_array)) > 1.0:
                 audio_array = audio_array / np.max(np.abs(audio_array))
 
             # Ensure float32 dtype
             audio_array = audio_array.astype(np.float32)
+
+        st.info(f"📊 Audio processed: {len(audio_array)/16000:.2f}s duration, shape: {audio_array.shape}")
 
         # Process audio with Whisper
         try:
+            st.info("🔄 Processing with Whisper...")
             # Try with language parameter first
-            input_features = whisper_processor(
-                audio_array,
-                sampling_rate=16000,
-                return_tensors="pt",
-                language="english"  # Set default language to English
-            ).input_features
-        except Exception as proc_error:
-            # Fallback without language parameter
             input_features = whisper_processor(
                 audio_array,
                 sampling_rate=16000,
                 return_tensors="pt"
             ).input_features
+
+            st.info("✅ Audio features extracted successfully")
+        except Exception as proc_error:
+            st.error(f"❌ Failed to process audio features: {proc_error}")
+            return ""
 
         # Get device and model info
         device = "cuda" if torch.cuda.is_available() else "cpu"
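The new validation guards (non-empty array, first channel of multi-channel input, at least 0.5 s of audio, peak normalization) all run before feature extraction. A condensed sketch of those checks feeding a transformers WhisperProcessor; the helper name prepare_features is hypothetical, and the thresholds mirror the diff.

import numpy as np

def prepare_features(audio: np.ndarray, processor):
    if audio is None or len(audio) == 0:
        raise ValueError("empty audio")
    if audio.ndim > 1:
        audio = audio[:, 0]  # stereo -> first channel
    if len(audio) < int(0.5 * 16000):  # require at least 0.5 s at 16 kHz
        raise ValueError("audio shorter than 0.5 s")
    peak = np.max(np.abs(audio))
    if peak > 1.0:  # rescale only when samples fall outside [-1, 1]
        audio = audio / peak
    audio = audio.astype(np.float32)
    # WhisperProcessor pads/trims to 30 s and returns log-mel input features
    return processor(audio, sampling_rate=16000, return_tensors="pt").input_features

Raising instead of returning "" keeps the sketch self-contained; the committed code reports through st.error/st.warning and returns an empty string instead.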
@@ -394,53 +449,65 @@ def transcribe_audio(audio_data, whisper_model, whisper_processor):
 
         # Generate transcription with error handling
         try:
+            st.info("🔄 Generating transcription...")
             with torch.no_grad():
-                forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(language="english", task="transcribe")
+                # Try simple generation first
                 predicted_ids = whisper_model.generate(
                     input_features,
                     max_length=448,  # Standard max length for Whisper
                     num_beams=1,  # Faster generation
                     do_sample=False,  # Deterministic output
-                    forced_decoder_ids=forced_decoder_ids,
+                    temperature=0.0,  # Deterministic
+                    use_cache=True
                 )
+            st.info("✅ Transcription generated successfully")
         except RuntimeError as e:
+            st.warning(f"⚠️ First attempt failed: {e}")
+            if "dtype" in str(e).lower() or "float16" in str(e).lower():
+                try:
+                    # Try forcing float32 for both input and model
+                    st.info("🔄 Retrying with float32...")
+                    input_features = input_features.float()
+                    if hasattr(whisper_model, 'float'):
+                        whisper_model = whisper_model.float()
+                    with torch.no_grad():
+                        predicted_ids = whisper_model.generate(
+                            input_features,
+                            max_length=448,
+                            num_beams=1,
+                            do_sample=False,
+                            temperature=0.0,
+                            use_cache=True
+                        )
+                    st.info("✅ Transcription generated with float32")
+                except Exception as e2:
+                    st.error(f"❌ Float32 attempt failed: {e2}")
+                    return ""
             else:
+                st.error(f"❌ Generation failed: {e}")
+                return ""
         except Exception as generation_error:
-            try:
-                with torch.no_grad():
-                    predicted_ids = whisper_model.generate(
-                        input_features,
-                        max_length=448,
-                        num_beams=1,
-                        do_sample=False
-                    )
-            except Exception as final_error:
-                raise final_error
+            st.error(f"❌ Unexpected generation error: {generation_error}")
+            return ""
 
         # Decode transcription
-        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-        return transcription.strip()
+        try:
+            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+            transcription = transcription.strip()
+
+            if not transcription:
+                st.warning("⚠️ Transcription is empty. The audio might be silent or unclear.")
+                return ""
+
+            st.success(f"✅ Transcription successful: '{transcription[:50]}{'...' if len(transcription) > 50 else ''}'")
+            return transcription
+
+        except Exception as decode_error:
+            st.error(f"❌ Failed to decode transcription: {decode_error}")
+            return ""
 
     except Exception as e:
-        st.error(f"Error transcribing audio: {e}")
+        st.error(f"❌ Error transcribing audio: {e}")
         logging.error(f"Transcription error: {e}")
         return ""
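Generation now retries once in float32 when a dtype or float16 RuntimeError appears, which typically happens when a model loaded in half precision receives float32 features on CPU. A condensed sketch of that retry-and-decode flow, assuming model and processor are a loaded WhisperForConditionalGeneration and WhisperProcessor; temperature is omitted here because it has no effect when do_sample=False.

import torch

def generate_text(model, processor, input_features) -> str:
    gen_kwargs = dict(max_length=448, num_beams=1, do_sample=False, use_cache=True)
    try:
        with torch.no_grad():
            ids = model.generate(input_features, **gen_kwargs)
    except RuntimeError as e:
        if "dtype" not in str(e).lower() and "float16" not in str(e).lower():
            raise
        # Cast both the features and the model to float32 and retry once
        input_features = input_features.float()
        model = model.float()
        with torch.no_grad():
            ids = model.generate(input_features, **gen_kwargs)
    # batch_decode strips the special tokens and returns plain text
    return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()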
@@ -1290,7 +1357,7 @@ def main():
                 )
                 if transcribed_text and transcribed_text != st.session_state.last_processed_message:
                     st.session_state.last_processed_message = transcribed_text
-                    st.success(f"Transcribed: {transcribed_text}")
+                    st.success(f"✅ Transcribed: {transcribed_text}")
                     # Add transcribed text to chat
                     st.session_state.messages.append({"role": "user", "content": transcribed_text})
 
@@ -1331,11 +1398,16 @@ def main():
 
                     # Trigger rerun to display the conversation
                     st.rerun()
+                elif not transcribed_text:
+                    # Reset states if transcription failed or returned empty
+                    st.session_state.waiting_for_input = True
+                    st.session_state.processing_complete = True
+                    st.error("❌ Could not transcribe audio. Please ensure the audio is clear and contains speech, then try again.")
                 else:
-                    # Reset states if transcription failed
+                    # Reset states if duplicate message
                     st.session_state.waiting_for_input = True
                     st.session_state.processing_complete = True
-                    st.
+                    st.warning("⚠️ This audio was already processed. Please record a new message.")
 
             except Exception:
                 # Fallback to file uploader
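The main() hunks split the non-rerun path into a failed-transcription branch and a duplicate-audio branch, keyed off st.session_state.last_processed_message. A minimal sketch of that dedup pattern using the names from the diff; the surrounding recorder wiring and the messages list are assumed to exist.

import streamlit as st

# Initialize once per session
if "last_processed_message" not in st.session_state:
    st.session_state.last_processed_message = None

def handle_transcription(transcribed_text: str) -> None:
    if transcribed_text and transcribed_text != st.session_state.last_processed_message:
        # New, non-empty result: remember it, append to the chat, redraw
        st.session_state.last_processed_message = transcribed_text
        st.session_state.messages.append({"role": "user", "content": transcribed_text})
        st.rerun()
    elif not transcribed_text:
        st.error("Could not transcribe audio. Please try again.")  # failed or empty result
    else:
        st.warning("This audio was already processed.")  # duplicate input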
@@ -1408,7 +1480,7 @@ def main():
                 # Reset states if transcription failed
                 st.session_state.waiting_for_input = True
                 st.session_state.processing_complete = True
-                st.error("Could not transcribe audio. Please try again.")
+                st.error("❌ Could not transcribe audio. Please ensure the audio file is valid and contains speech, then try again.")
 
             # Show ready status when waiting for input
             if st.session_state.waiting_for_input and st.session_state.processing_complete: