import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel


class BCEWithLogitsLossLS(nn.Module):
    """Binary cross-entropy with logits, with optional label smoothing."""

    def __init__(self, label_smoothing=0.1, pos_weight=None, reduction='mean'):
        super().__init__()
        assert 0 <= label_smoothing < 1, "label_smoothing value must be between 0 and 1."
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.bce_with_logits = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction=reduction)

    def forward(self, input, target):
        if self.label_smoothing > 0:
            # Map hard targets {0, 1} to {label_smoothing, 1 - label_smoothing}.
            positive_smoothed_labels = 1.0 - self.label_smoothing
            negative_smoothed_labels = self.label_smoothing
            target = target * positive_smoothed_labels + \
                (1 - target) * negative_smoothed_labels

        loss = self.bce_with_logits(input, target)
        return loss


class WavLMForEndpointing(nn.Module):
    """WavLM encoder with attention pooling and a binary endpointing head."""

    def __init__(self, config, n_trainable_layers=6):
        super().__init__()
        self.wavlm = AutoModel.from_pretrained('microsoft/wavlm-base-plus', config=config)
        self.config = config
        self.n_trainable_layers = n_trainable_layers

        # Freeze the backbone, then unfreeze only the top n_trainable_layers encoder layers.
        for param in self.wavlm.parameters():
            param.requires_grad = False
        if self.n_trainable_layers > 0:
            for i in range(self.n_trainable_layers):
                for param in self.wavlm.encoder.layers[-(i + 1)].parameters():
                    param.requires_grad = True

        # Scores each frame for attention pooling.
        self.pool_attention = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )

        # Binary classification head on the pooled representation.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

        # Initialize the new linear layers with small Gaussian weights and zero biases.
        for module in self.classifier:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()

        for module in self.pool_attention:
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.1)
                if module.bias is not None:
                    module.bias.data.zero_()

    def attention_pool(self, hidden_states, attention_mask):
        attention_weights = self.pool_attention(hidden_states)

        if attention_mask is None:
            raise ValueError("attention_mask must be provided for attention pooling")

        # Mask out padded frames before the softmax.
        attention_weights = attention_weights + (
            (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
        )
        attention_weights = F.softmax(attention_weights, dim=1)

        # Apply attention to hidden states
        weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)

        return weighted_sum

    def forward(self, input_values, attention_mask=None, labels=None):
        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs[0]

        if attention_mask is not None:
            # Downsample the sample-level mask to the frame rate of the encoder output.
            input_length = attention_mask.size(1)
            hidden_length = hidden_states.size(1)
            ratio = input_length / hidden_length
            indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
            attention_mask = attention_mask[:, indices].bool()

        pooled = self.attention_pool(hidden_states, attention_mask)
        logits = self.classifier(pooled)

        if torch.isnan(logits).any():
            raise ValueError("NaN values detected in logits")

        if labels is not None:
            # Weight positives by the negative/positive ratio, clamped to a sane range.
            pos_weight = ((labels == 0).sum() / (labels == 1).sum()).clamp(min=0.1, max=10.0)
            loss_fct = BCEWithLogitsLossLS(pos_weight=pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1), labels.view(-1))

            # L2 regularization on the classifier head only.
            l2_lambda = 0.01
            l2_reg = torch.tensor(0., device=logits.device)
            for param in self.classifier.parameters():
                l2_reg = l2_reg + torch.norm(param)
            loss = loss + l2_lambda * l2_reg

            probs = torch.sigmoid(logits.detach())
            return {"loss": loss, "logits": probs}

        probs = torch.sigmoid(logits)
        return {"logits": probs}
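# --- Minimal usage sketch (not part of the original module) ---
# Assumes the 'microsoft/wavlm-base-plus' checkpoint is available; the 1-second
# 16 kHz dummy waveforms and the binary labels below are illustrative placeholders,
# and "1 = endpoint reached" is an assumed label convention.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("microsoft/wavlm-base-plus")
    model = WavLMForEndpointing(config, n_trainable_layers=2)
    model.eval()

    batch_size, num_samples = 2, 16000                          # 1 s of 16 kHz audio per example
    input_values = torch.randn(batch_size, num_samples)         # placeholder waveforms
    attention_mask = torch.ones(batch_size, num_samples, dtype=torch.long)
    labels = torch.tensor([1.0, 0.0])                           # assumed: 1 = endpoint, 0 = still speaking

    with torch.no_grad():
        out = model(input_values, attention_mask=attention_mask, labels=labels)

    print(out["loss"].item())      # scalar loss (BCE with label smoothing + L2 on the head)
    print(out["logits"].shape)     # (batch_size, 1) endpoint probabilities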