# plm_ola_internvl / speech_encoder.py
import os
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
import whisper
from transformers import WhisperFeatureExtractor

# Default newly created tensors to CPU so meta-device placement does not
# leak into module construction.
try:
    torch.set_default_device("cpu")
except Exception:
    pass

import accelerate
# from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs
from beats_model import BEATsConfig, BEATs
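
# A minimal sketch (not part of the original module) of the meta-tensor
# pattern this file works around: parameters created on the "meta" device
# carry no storage, so they must be materialized with `to_empty` before a
# state dict can be loaded. The 4x4 Linear below is purely illustrative.
def _meta_materialization_sketch():
    with torch.device("meta"):
        probe = nn.Linear(4, 4)          # parameters are meta tensors: no data
    assert all(p.is_meta for p in probe.parameters())
    probe.to_empty(device="cpu")         # allocate real, uninitialized CPU storage
    donor = nn.Linear(4, 4)              # a real module to borrow weights from
    probe.load_state_dict(donor.state_dict())  # loading now succeeds
    return probe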


class WhisperWrappedEncoder:

    @classmethod
    def load(cls, model_config):

        def replace_layer_norm(module):
            # Replace whisper's custom LayerNorm with the stock nn.LayerNorm,
            # rebuilding any layers whose parameters are still meta tensors.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    # Check whether any parameter is still a meta tensor
                    has_meta = any(p.is_meta for p in child.parameters())
                    if has_meta:
                        # Meta tensors carry no data: build a freshly
                        # initialized LayerNorm with the same shape.
                        new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    else:
                        # Real tensors: copy the existing affine parameters.
                        old_params = child.state_dict()
                        new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                        new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        # Load the whisper model, handling both file paths and registered
        # model names.
        speech_encoder_path = model_config.speech_encoder

        # First try loading directly (works for both file paths and names).
        try:
            encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder
        except (NotImplementedError, RuntimeError) as e:
            if "meta tensor" in str(e):
                # Meta-tensor issue: bypass whisper.load_model's device
                # placement and rebuild the model from the raw checkpoint.
                print("Detected meta tensor issue, using alternative loading approach...")
                if os.path.isfile(speech_encoder_path):
                    # Load the checkpoint directly onto CPU.
                    checkpoint = torch.load(speech_encoder_path, map_location='cpu')
                    # Recreate the model from the stored dimensions.
                    from whisper.model import ModelDimensions, Whisper
                    dims = ModelDimensions(**checkpoint["dims"])
                    model = Whisper(dims)
                    # Load the state dict without moving devices.
                    model.load_state_dict(checkpoint["model_state_dict"])
                    encoder = model.encoder
                else:
                    # A registered model name cannot be rebuilt this way.
                    raise RuntimeError(f"Cannot load model {speech_encoder_path} due to meta tensor issues")
            else:
                raise

        replace_layer_norm(encoder)
        return encoder
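
# A hedged usage sketch for WhisperWrappedEncoder (assumption: model_config
# only needs a `speech_encoder` attribute; "base" is one of whisper's
# registered model names and will be downloaded if not already cached).
def _demo_whisper_encoder():
    cfg = types.SimpleNamespace(speech_encoder="base")
    encoder = WhisperWrappedEncoder.load(cfg)
    # whisper encoders consume 30 s of log-mel features (3000 frames);
    # the "base" model uses 80 mel bins and a 512-dim hidden state.
    mel = torch.zeros(1, 80, 3000)
    with torch.no_grad():
        feats = encoder(mel)
    return feats                          # expected shape: (1, 1500, 512)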


class DualWrappedEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):

        def replace_layer_norm(module):
            # Same LayerNorm swap as in WhisperWrappedEncoder.load above.
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    # Check whether any parameter is still a meta tensor
                    has_meta = any(p.is_meta for p in child.parameters())
                    if has_meta:
                        # Meta tensors: fresh LayerNorm with the same shape.
                        new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                    else:
                        old_params = child.state_dict()
                        new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine)
                        new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        # Load the whisper checkpoint from the configured path.
        speech_encoder_path = model_config.speech_encoder
        from whisper.model import Whisper, ModelDimensions

        # 1) Load the checkpoint to CPU (weights are real tensors here).
        ckpt = torch.load(speech_encoder_path, map_location="cpu")
        dims = ModelDimensions(**ckpt["dims"])

        # 2) Build the module skeleton, then materialize its tensors on CPU.
        model = Whisper(dims)
        model.to_empty(device="cpu")  # crucial when meta tensors are involved

        # 3) Load the weights into the materialized module.
        missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=True)
        model.eval()

        encoder = model.encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        # Materialize any meta tensors on CPU before loading the weights.
        beats = beats.to_empty(device='cpu')
        beats.load_state_dict(beats_ckpt['model'], strict=True)
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            try:
                # Run BEATs on the raw waveform (cast to fp32 first).
                raw_wav_float = raw_wav.float()
                audio_embeds, _ = self.beats_model.extract_features(raw_wav_float, padding_mask=audio_padding_mask, feature_only=True)
            except Exception:
                # If BEATs feature extraction fails, fall back to zero
                # features (BEATs emits 1024-dim embeddings) so the
                # concatenation below still works.
                audio_embeds = torch.zeros(speech_embeds.shape[0], speech_embeds.shape[1], 1024, device=speech_embeds.device, dtype=speech_embeds.dtype)
        # Pad whichever stream is shorter along the time axis so the two
        # sequences align before concatenation.
        if audio_embeds.size(1) < speech_embeds.size(1):
            audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
        elif audio_embeds.size(1) > speech_embeds.size(1):
            speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
        # Concatenate along the feature dimension and cast for the LLM.
        speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
        speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
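
# A hedged end-to-end sketch for DualWrappedEncoder. The checkpoint paths
# below are placeholders, not real files. With whisper large-v3 (1280-dim
# features, 128 mel bins) and BEATs (1024-dim features), the concatenated
# output would be 2304-dim in bfloat16.
if __name__ == "__main__":
    cfg = types.SimpleNamespace(
        speech_encoder="/path/to/whisper/large-v3.pt",  # placeholder path
        music_encoder="/path/to/BEATs_checkpoint.pt",   # placeholder path
    )
    enc = DualWrappedEncoder(cfg)
    mel = torch.zeros(1, 128, 3000)                 # 30 s of log-mel input
    wav = torch.zeros(1, 16000 * 30)                # 30 s of 16 kHz waveform
    mask = torch.zeros(1, 16000 * 30, dtype=torch.bool)
    out = enc(mel, raw_wav=wav, audio_padding_mask=mask)
    print(out.shape, out.dtype)                     # expect (1, T, 2304) bfloat16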