| | import dataclasses
|
| | import os
|
| | import os.path
|
| | import re
|
| |
|
| | from datasets import load_dataset
|
| | from datasets import Audio
|
| | import jiwer
|
| | import torch
|
| | from transformers import AutoProcessor, Wav2Vec2ForCTC
|
| | from transformers.models.wav2vec2.processing_wav2vec2 import Wav2Vec2Processor
|
| |
|
# Hugging Face Hub id of the fine-tuned Esperanto wav2vec2 CTC model under test.
MODEL = "xekri/wav2vec2-common_voice_13_0-eo-10_1"

# Dataset split slice to evaluate (first 10 rows of the validation split).
DATA = "validation[:10]"


# Character class stripped from reference transcripts before scoring
# (punctuation, assorted quote marks, the music note, etc.).
chars_to_ignore_regex = "[-!\"'(),.:;=?_`¨«¸»ʼ‑–—‘’“”„…‹›♫?]"

# Substring -> replacement map used to normalize reference transcripts toward
# the model's Esperanto orthography (x-system digraphs like "cx"/"sx",
# loan-word spellings, accented Latin letters).
# NOTE(review): the "fi" -> "fi" and "fl" -> "fl" entries look like identity
# mappings; presumably the keys are the Unicode ligatures ﬁ/ﬂ — confirm.
chars_to_substitute = {
    "przy": "pŝe",
    "byn": "bin",
    "cx": "ĉ",
    "sx": "ŝ",
    "fi": "fi",
    "fl": "fl",
    "ǔ": "ŭ",
    "ñ": "nj",
    "á": "a",
    "é": "e",
    "ü": "ŭ",
    "y": "j",
    "qu": "ku",
}
|
| |
|
| |
|
def remove_special_characters(text: str) -> str:
    """Strip ignorable punctuation/symbols and lowercase the transcript.

    Uses the module-level ``chars_to_ignore_regex`` character class.
    """
    cleaned = re.sub(chars_to_ignore_regex, "", text)
    return cleaned.lower()
|
| |
|
| |
|
def substitute_characters(text: str) -> str:
    """Apply the ``chars_to_substitute`` replacements and lowercase the result.

    Bug fix: ``str.replace`` returns a new string — the original code called
    ``text.replace(k, v)`` and discarded the result, so no substitution was
    ever applied.  The result is now rebound to ``text`` on each iteration.
    """
    for old, new in chars_to_substitute.items():
        text = text.replace(old, new)
    return text.lower()
|
| |
|
| |
|
@dataclasses.dataclass
class EvalResult:
    """Per-example evaluation outcome: transcripts plus CER and CTC loss."""

    filename: str
    cer: float
    loss: float
    actual: str
    predicted: str

    def print(self) -> None:
        """Write this result to stdout, one labelled line per field."""
        report = (
            f"FILE {self.filename}",
            f"CERR {self.cer}",
            f"LOSS {self.loss}",
            f"ACTU {self.actual}",
            f"PRED {self.predicted}",
        )
        print("\n".join(report))
|
| |
|
| |
|
def evaluate(processor: Wav2Vec2Processor, model, example) -> EvalResult:
    """Evaluate a single Common Voice example against the model.

    Args:
        processor: Wav2Vec2 processor for feature extraction, tokenization,
            and CTC decoding.
        model: A ``Wav2Vec2ForCTC``-compatible model.
        example: Dataset row with ``"path"``, ``"audio"`` (dict holding a
            16 kHz ``"array"``), and ``"sentence"`` keys.

    Returns:
        EvalResult with the basename of the audio file, character error rate,
        CTC loss, and the normalized reference / predicted transcripts.
    """
    # The filename reported is just the basename of the dataset path.  The
    # original directory-walking reconstruction (split + listdir + join)
    # produced the same basename while adding an avoidable listdir failure
    # mode, so it has been removed.
    filename = os.path.basename(example["path"])

    inputs = processor(
        audio=example["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    )

    # Normalize the reference transcript the same way as the training data.
    actual = substitute_characters(remove_special_characters(example["sentence"]))
    inputs["labels"] = processor(text=actual, return_tensors="pt").input_ids

    # One forward pass yields both logits and loss (labels are supplied).
    # Inference only, so gradients are disabled for the whole pass; the
    # original computed the loss with gradients enabled in a second pass.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_ids = outputs.logits.argmax(dim=-1)
    predict = processor.batch_decode(predicted_ids)[0]

    cer = jiwer.cer(actual, predict)
    return EvalResult(filename, cer, outputs.loss, actual, predict)
|
| |
|
| |
|
def run() -> None:
    """Evaluate MODEL on the DATA slice of Common Voice Esperanto.

    Prints a Markdown table of reference/predicted transcripts with the
    per-example character error rate.
    """
    cv13 = load_dataset(
        "mozilla-foundation/common_voice_13_0",
        "eo",
        split=DATA,
    )
    # Decode audio at the 16 kHz sampling rate the wav2vec2 model expects.
    cv13 = cv13.cast_column("audio", Audio(sampling_rate=16000))

    processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(MODEL)
    model = Wav2Vec2ForCTC.from_pretrained(MODEL)

    print("| Actual<br>Predicted | CER |")
    print("|:--------------------|:----|")

    # The original enumerated the dataset but never used the index.
    for example in cv13:
        results = evaluate(processor, model, example)
        print(f"| `{results.actual}`<br>`{results.predicted}` | {results.cer} |")
|
| |
|
| |
|
# Script entry point: run the evaluation when executed directly.
if __name__ == "__main__":
    run()