import types import torch import torch.nn as nn import torch.nn.functional as F from transformers import WhisperFeatureExtractor import whisper import torch try: torch.set_default_device("cpu") except Exception: pass import accelerate # from ola.model.speech_encoder.beats.BEATs import BEATsConfig, BEATs from beats_model import BEATsConfig, BEATs class WhisperWrappedEncoder: @classmethod def load(cls, model_config): def replace_layer_norm(module): from whisper.model import LayerNorm for name, child in module.named_children(): if isinstance(child, LayerNorm): # Check if any parameter is a meta tensor has_meta = any(p.is_meta for p in child.parameters()) if has_meta: # For meta tensors, create new layer norm with same shape new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) else: old_params = child.state_dict() new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) new_layer_norm.load_state_dict(old_params) setattr(module, name, new_layer_norm) else: replace_layer_norm(child) # Load whisper model, handling both file paths and model names speech_encoder_path = model_config.speech_encoder # First try loading directly (works for both file paths and model names) try: encoder = whisper.load_model(name=speech_encoder_path, device='cpu').encoder except (NotImplementedError, RuntimeError) as e: if "meta tensor" in str(e): # Meta tensor issue - load model without device specification print(f"Detected meta tensor issue, using alternative loading approach...") # Load checkpoint directly to avoid device issues import os if os.path.isfile(speech_encoder_path): # Load from file checkpoint = torch.load(speech_encoder_path, map_location='cpu') # Create model from checkpoint from whisper.model import ModelDimensions, Whisper dims = ModelDimensions(**checkpoint["dims"]) model = Whisper(dims) # Load state dict without moving to device model.load_state_dict(checkpoint["model_state_dict"]) # Get encoder without device movement encoder = model.encoder else: # Try loading as model name without device import whisper.model as whisper_model # This is a fallback - may need adjustment based on actual model raise RuntimeError(f"Cannot load model {speech_encoder_path} due to meta tensor issues") else: raise e replace_layer_norm(encoder) return encoder class DualWrappedEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.whisper_model = self.load_whisper(config) self.beats_model = self.load_beats(config) def load_whisper(self, model_config): def replace_layer_norm(module): from whisper.model import LayerNorm for name, child in module.named_children(): if isinstance(child, LayerNorm): # Check if any parameter is a meta tensor has_meta = any(p.is_meta for p in child.parameters()) if has_meta: # For meta tensors, create new layer norm with same shape new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) else: old_params = child.state_dict() new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) new_layer_norm.load_state_dict(old_params) setattr(module, name, new_layer_norm) else: replace_layer_norm(child) # Load whisper model, handling both file paths and model names speech_encoder_path = model_config.speech_encoder import torch from whisper.model import Whisper, ModelDimensions # 1) Load checkpoint to CPU (weights are real tensors here) ckpt = torch.load("/data1/cxy/plm-v/modeling/cache/large-v3.pt", map_location="cpu") dims = ModelDimensions(**ckpt["dims"]) # 2) Build the module skeleton, then MATERIALIZE tensors on CPU model = Whisper(dims) model.to_empty(device="cpu") # <-- crucial when meta is involved # 3) Load weights missing, unexpected = model.load_state_dict(ckpt["model_state_dict"], strict=True) model.eval() encoder = model.encoder replace_layer_norm(encoder) return encoder def load_beats(self, model_config): beats_path = model_config.music_encoder print("Loading BEATs Model") beats_ckpt = torch.load(beats_path, map_location='cpu') beats_cfg = BEATsConfig(beats_ckpt['cfg']) beats = BEATs(beats_cfg) beats = beats.to_empty(device='cpu') # Load state dict beats.load_state_dict(beats_ckpt['model'], strict=True) return beats def forward(self, x, raw_wav=None, audio_padding_mask=None): with torch.no_grad(): self.beats_model = self.beats_model.float() speech_embeds = self.whisper_model(x) try: # 详细检查BEATs模型输入 raw_wav_float = raw_wav.float() audio_embeds, _ = self.beats_model.extract_features(raw_wav_float, padding_mask=audio_padding_mask, feature_only=True) except Exception as e: audio_embeds = torch.zeros(speech_embeds.shape[0], speech_embeds.shape[1], 1024, device=speech_embeds.device, dtype=speech_embeds.dtype) if audio_embeds.size(1) < speech_embeds.size(1): audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))) elif audio_embeds.size(1) > speech_embeds.size(1): speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1))) speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1) speech_embeds = speech_embeds.to(torch.bfloat16) return speech_embeds