|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Feature extractor class for Whisper |
|
|
""" |
|
|
import math |
|
|
from functools import partial |
|
|
from typing import List, Optional, Union |
|
|
from collections import deque |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import WhisperFeatureExtractor |
|
|
from transformers.audio_utils import mel_filter_bank |
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
from transformers.utils import TensorType, logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class ExtractorIterator:
    """Iterable that cuts waveforms into overlapping chunks and encodes them in batches.

    Every item of ``data`` is split into windows of ``chunk_length`` seconds
    whose start positions are ``chunk_length - overlap_seconds`` seconds apart.
    Windows are packed into batches of at most ``batch_size`` rows; a batch may
    hold windows from several items, and the windows of one item may be spread
    over several batches. Each batch is encoded with ``encode_func``.

    Args:
        data: Iterable of tensors indexable as ``tensor[0:1, :]``; only that
            first row is chunked (presumably channel 0 of a
            ``(channels, samples)`` waveform -- confirm against the caller).
        batch_size: Maximum number of chunks per yielded batch.
        chunk_length: Window length in seconds.
        overlap_seconds: Overlap between consecutive windows in seconds.
        overlap_side: ``"right"`` places the whole overlap at the end of each
            window; any other value splits it evenly between both ends.
        sampling_rate: Samples per second, used to convert seconds to samples.
        encode_func: Callable that receives a list of 1-D numpy arrays and
            returns a mapping; the mapping is merged into each yielded batch.

    Raises:
        AssertionError: If ``encode_func`` is not callable, or if the overlap
            (in samples) is odd while ``overlap_side`` is not ``"right"``.
        ValueError: If the overlap is not strictly shorter than the chunk.
    """

    def __init__(
        self,
        data,
        batch_size=8,
        chunk_length=30,
        overlap_seconds=10,
        overlap_side="both",
        sampling_rate=16000,
        encode_func = None,
    ) -> None:
        self.data = data
        self.batch_size = batch_size
        self.chunk_length = chunk_length
        self.overlap_seconds = overlap_seconds
        self.overlap_side = overlap_side
        self.sampling_rate = sampling_rate

        # Window geometry expressed in samples.
        self.chunk_size = int(self.chunk_length * self.sampling_rate)
        self.overlap_size = int(self.overlap_seconds * self.sampling_rate)
        # Distance between the starts of two consecutive windows.
        self.duration_size = self.chunk_size - self.overlap_size
        assert (
            (overlap_side == "right") or (self.overlap_size % 2 == 0)
        ), '`overlap_seconds` must be divisible by 2 when `overlap_side` is "both".'

        assert callable(encode_func)
        self.encode_func = encode_func

        # A non-positive stride would otherwise only surface later, as a
        # ZeroDivisionError or an unfold() error in the middle of iteration.
        if self.duration_size <= 0:
            raise ValueError(
                "`overlap_seconds` must be smaller than `chunk_length` "
                f"(got overlap_seconds={overlap_seconds}, chunk_length={chunk_length})."
            )

    def __iter__(self):
        """Return a generator that takes care of all of the batching logic.

        Yields:
            One ``BatchFeature`` per batch holding the output of
            ``encode_func`` plus ``"input_lengths"`` (the ``(left, right)``
            boundaries of every chunk in the batch) and ``"chunk_seq_no"``
            (index in ``data`` of the item each chunk was cut from).

        Raises:
            ValueError: If there is data to batch but ``batch_size`` is
                smaller than 1.
        """
        filled = 0

        # Buffers are allocated once and reused for every batch.
        wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
        input_lengths = deque(maxlen=self.batch_size)
        input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)

        for i, sample in enumerate(self.data):
            if self.batch_size < 1:
                # A batch without capacity can never absorb a chunk; the
                # packing loop below would spin forever.
                raise ValueError(f"`batch_size` must be at least 1, got {self.batch_size}.")

            sample_chunks, sample_lengths, sample_seq_no = self.chunk_and_pad_view(sample, i)
            total = len(sample_chunks)

            consumed = 0
            while consumed < total:
                # Take as many chunks as still fit into the current batch.
                take = min(self.batch_size - filled, total - consumed)
                src = slice(consumed, consumed + take)
                dst = slice(filled, filled + take)

                wav_tensor[dst] = sample_chunks[src]
                input_lengths.extend(sample_lengths[src])
                input_seq_no[dst] = sample_seq_no[src]

                filled += take
                consumed += take

                if filled == self.batch_size:
                    yield self._build_batch(wav_tensor, input_lengths, input_seq_no, filled)

                    # Reset the shared buffers for the next batch.
                    filled = 0
                    wav_tensor.zero_()
                    input_lengths.clear()
                    input_seq_no.zero_()

        # Flush the trailing, partially filled batch.
        if filled > 0:
            yield self._build_batch(wav_tensor, input_lengths, input_seq_no, filled)

    def _build_batch(self, wav_tensor, input_lengths, input_seq_no, count):
        """Encode the first ``count`` buffered chunks and wrap them in a ``BatchFeature``."""
        right_boundary = self.get_right_boundary()

        waveforms = []
        for row, (_, right) in enumerate(input_lengths):
            # A chunk whose right edge is the regular window boundary and that
            # still carries non-zero samples past it is passed on whole; any
            # other chunk is cut at its right edge.
            if right == right_boundary and torch.any(wav_tensor[row, :, right:] != 0):
                chunk = wav_tensor[row]
            else:
                chunk = wav_tensor[row, :, :right]
            # NOTE: for CPU tensors numpy() shares memory with the reusable
            # buffer, so these arrays are only valid until the next batch.
            waveforms.append(chunk.reshape(-1).cpu().numpy())

        return BatchFeature({
            **self.encode_func(waveforms),
            # Snapshot: the working deque is cleared and refilled for the next
            # batch, which would otherwise mutate batches the caller kept.
            "input_lengths": input_lengths.copy(),
            "chunk_seq_no": input_seq_no[:count].clone(),
        })

    def chunk_and_pad_view(self, tensor, seq_no):
        """Cut the first row of ``tensor`` into overlapping, zero-padded windows.

        Args:
            tensor: Tensor indexable as ``tensor[0:1, :]``.
            seq_no: Value stored for every produced chunk in the returned
                sequence-number tensor.

        Returns:
            A tuple ``(chunks, boundaries, seq_nos)``: ``chunks`` has shape
            ``(num_chunks, 1, chunk_size)``, ``boundaries`` is the list from
            ``get_windows_boundaries`` and ``seq_nos`` is a long tensor of
            length ``num_chunks`` filled with ``seq_no``.
        """
        x = tensor[0:1, :].unsqueeze(0)

        stride = self.duration_size
        kernel = self.chunk_size
        length = x.shape[-1]

        # One window always exists; further windows are added until the whole
        # signal is covered.
        num_chunks = max(0, math.ceil((length - kernel) / stride)) + 1
        target_len = (num_chunks - 1) * stride + kernel
        padding_size = max(0, target_len - length)
        x_padded = F.pad(x, (0, padding_size), "constant", 0)
        # unfold -> (1, 1, num_chunks, kernel); reorder to (num_chunks, 1, kernel).
        output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)

        output_lengths = self.get_windows_boundaries(num_chunks, length)
        output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
        return output_tensor, output_lengths, output_seq_no

    def get_left_boundary(self):
        """Return the left boundary (in samples) of a regular window."""
        if self.overlap_side == "right":
            return 0
        return self.overlap_size // 2

    def get_right_boundary(self):
        """Return the right boundary (in samples) of a regular window."""
        if self.overlap_side == "right":
            return self.duration_size
        return self.chunk_size - self.overlap_size // 2

    def get_windows_boundaries(self, num_chunks, seq_len):
        """Return the ``(left, right)`` boundaries for each of ``num_chunks`` windows.

        All windows use the regular boundaries, except that the first window
        starts at 0 and the last window ends at the number of samples that
        remain after the preceding ``num_chunks - 1`` strides.
        """
        left_boundary = self.get_left_boundary()
        right_boundary = self.get_right_boundary()

        output_lengths = [(left_boundary, right_boundary)] * num_chunks
        output_lengths[0] = (0, output_lengths[0][1])
        output_lengths[-1] = (output_lengths[-1][0], seq_len - self.duration_size * (num_chunks - 1))
        return output_lengths
|
|
|
|
|
|
|
|
class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
    """Whisper feature extractor whose ``__call__`` returns a chunking batch iterator.

    On top of the ``WhisperFeatureExtractor`` arguments it stores
    ``max_frequency`` (upper edge of the mel filter bank, defaulting to half
    the sampling rate), ``batch_size`` and ``overlap_side`` (both handed to
    ``ExtractorIterator``), and rebuilds ``self.mel_filters`` with that
    ``max_frequency``.
    """

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
        n_samples=480000,
        nb_max_frames=3000,
        padding_side="right",
        padding_value=0.0,
        dither=0.0,
        return_attention_mask=False,
        max_frequency=None,
        batch_size=8,
        overlap_side="both",
        **kwargs,
    ):
        # Arguments handed straight to the Whisper base class.
        base_arguments = dict(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            hop_length=hop_length,
            chunk_length=chunk_length,
            n_fft=n_fft,
            padding_value=padding_value,
            dither=dither,
            return_attention_mask=return_attention_mask,
            n_samples=n_samples,
            nb_max_frames=nb_max_frames,
            padding_side=padding_side,
        )
        super().__init__(**base_arguments, **kwargs)

        # Without an explicit limit the filter bank spans up to sampling_rate / 2.
        if max_frequency is None:
            max_frequency = sampling_rate / 2
        self.max_frequency = max_frequency
        self.batch_size = batch_size
        self.overlap_side = overlap_side

        # Replace the mel filters so that they honour ``max_frequency``.
        self.mel_filters = mel_filter_bank(
            num_frequency_bins=1 + n_fft // 2,
            num_mel_filters=feature_size,
            min_frequency=0.0,
            max_frequency=self.max_frequency,
            sampling_rate=sampling_rate,
            norm="slaney",
            mel_scale="slaney",
        )

    def __call__(
        self,
        raw_speech: Union[torch.Tensor, List[torch.Tensor]],
        truncation: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: Optional[bool] = None,
        padding: Optional[str] = "max_length",
        max_length: Optional[int] = None,
        sampling_rate: Optional[int] = None,
        do_normalize: Optional[bool] = None,
        device: Optional[str] = "cpu",
        return_token_timestamps: Optional[bool] = None,
        overlap_seconds: int = 10,
        **kwargs,
    ) -> ExtractorIterator:
        """Wrap ``raw_speech`` in an ``ExtractorIterator`` that encodes it lazily.

        Args:
            raw_speech: A single tensor or a list of tensors; anything that is
                not a list is wrapped in a one-element list.
            overlap_seconds: Overlap between consecutive chunks, in seconds.
            **kwargs: Together with all remaining named arguments, bound onto
                the parent class ``__call__``, which the iterator uses as its
                ``encode_func``.

        Returns:
            An ``ExtractorIterator`` over the given speech. Its batch size is
            ``self.batch_size``, or the number of inputs when that is falsy.
        """
        speech_items = raw_speech if isinstance(raw_speech, list) else [raw_speech]

        # Parent-class encoder with every per-call option already bound.
        encode_chunks = partial(
            super().__call__,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
            padding=padding,
            max_length=max_length,
            sampling_rate=sampling_rate,
            do_normalize=do_normalize,
            device=device,
            return_token_timestamps=return_token_timestamps,
            **kwargs,
        )

        return ExtractorIterator(
            speech_items,
            batch_size=self.batch_size or len(speech_items),
            chunk_length=self.chunk_length,
            overlap_seconds=overlap_seconds,
            overlap_side=self.overlap_side,
            sampling_rate=self.sampling_rate,
            encode_func=encode_chunks,
        )
|
|
|