"""Interactive inference script for a fine-tuned WavLM endpointing model.

Reads an audio file path from stdin, preprocesses the audio, and prints
the model's output.
"""

from safetensors import safe_open
import torch
import torchaudio
import transformers

from model import WavLMForEndpointing

MODEL_NAME = "microsoft/wavlm-base-plus"

# Load the feature extractor and config from the base checkpoint, then
# instantiate the custom endpointing model.
processor = transformers.AutoFeatureExtractor.from_pretrained(MODEL_NAME)
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
model = WavLMForEndpointing(config)

# Load the fine-tuned weights from the safetensors checkpoint.
checkpoint_path = "/home/nikita/wavlm-endpointing-model/checkpoint-29000/model.safetensors"
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
    state_dict = {key: f.get_tensor(key) for key in f.keys()}
model.load_state_dict(state_dict)
model.eval()

# Simple loop: read an audio path from stdin, run inference, print the result.
while True:
    audio_path = input("Enter path to an audio file: ").strip()

    waveform, sample_rate = torchaudio.load(audio_path)

    # Resample to the 16 kHz rate WavLM expects.
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    inputs = processor(
        waveform.squeeze().numpy(),
        sampling_rate=16000,
        return_tensors="pt",
        padding=False,
        truncation=False,
    )

    with torch.no_grad():
        result = model(**inputs)

    print(result)