Spaces:
Running
Running
/**
 * Punctuation and capitalization restoration using ONNX Runtime.
 * - English: full punctuation + capitalization (1-800-BAD-CODE model)
 * - Other languages (DE, FR, IT, NL, ES, PT): punctuation only (oliverguhr multilingual model)
 */
// ---- Lazily initialised model state (null until the matching loader runs) ----

// English model (punctuation + capitalization): ONNX session plus vocab maps.
let pcsSession = null;      // ort.InferenceSession
let pcsVocab = null;        // piece -> id
let pcsVocabReverse = null; // id -> piece

// Multilingual model (punctuation only): ONNX session plus transformers.js tokenizer.
let multilingualSession = null;
let multilingualTokenizer = null;
/**
 * Label tables and special token ids for the English (1-800-BAD-CODE) model.
 * `preLabels` are inserted before a token; `postLabels` after it. Post label 1
 * (`<ACRONYM>`) means "a period after every character of the token".
 */
const PCS_CONFIG = {
  preLabels: ['<NULL>', '¿'],
  postLabels: ['<NULL>', '<ACRONYM>', '.', ',', '?'],
  unkId: 0,
  bosId: 1,
  eosId: 2,
  padId: 3,
};

/** Class index -> punctuation emitted by the multilingual model (0 = none). */
const MULTILINGUAL_LABELS = {
  0: '',  // none
  1: '.', // period
  2: ',', // comma
  3: '?', // question mark
  4: '-', // hyphen
  5: ':', // colon
};

/** Language codes routed to the multilingual (punctuation-only) model. */
const MULTILINGUAL_LANGS = ['de', 'fr', 'it', 'nl', 'es', 'pt'];
/**
 * Fetch `url`, serving it from the Cache API when a cached copy exists.
 *
 * Successful (2xx) network responses are stored in the
 * 'granite-speech-local-models' cache so large model files are downloaded once.
 * When the Cache API is unavailable (the global `caches` only exists in secure
 * contexts) or a cache operation fails (e.g. storage quota exceeded), this falls
 * back to a plain network fetch instead of failing the whole model load.
 *
 * @param {string} url - Resource URL; relative URLs resolve against the page.
 * @returns {Promise<Response>} The cached or network response. Non-OK network
 *   responses are returned unchanged (and not cached); callers must check `ok`.
 */
async function cachedFetch(url) {
  let cache = null;
  if (typeof caches !== 'undefined') {
    try {
      cache = await caches.open('granite-speech-local-models');
      const cached = await cache.match(url);
      if (cached) return cached;
    } catch (err) {
      console.warn('Cache API unavailable, fetching without cache:', err);
      cache = null;
    }
  }

  const response = await fetch(url);
  if (cache && response.ok) {
    try {
      await cache.put(url, response.clone());
    } catch (err) {
      // Failing to cache must not prevent using the freshly downloaded response.
      console.warn(`Failed to cache ${url}:`, err);
    }
  }
  return response;
}
/** In-flight English model load shared by concurrent callers; null when idle. */
let englishLoadPromise = null;

/**
 * Download the English vocab and ONNX model and build the inference session.
 *
 * Module state (pcsVocab, pcsVocabReverse, pcsSession) is assigned only after
 * every step has succeeded, so a failed load leaves nothing half-initialised.
 *
 * @returns {Promise<void>}
 * @throws {Error} If a download returns a non-OK status, or parsing / session creation fails.
 */
async function createEnglishPunctuator() {
  console.log('Loading English punctuator model...');

  const vocabResponse = await cachedFetch('./pcs_vocab.json');
  if (!vocabResponse.ok) {
    throw new Error(`Failed to fetch ./pcs_vocab.json (HTTP ${vocabResponse.status})`);
  }
  const vocab = (await vocabResponse.json()).vocab;
  // Inverse map (id -> piece) for decoding token ids back to text.
  const vocabReverse = Object.fromEntries(
    Object.entries(vocab).map(([piece, id]) => [id, piece]),
  );

  const modelResponse = await cachedFetch('./punct_cap_seg_en.onnx');
  if (!modelResponse.ok) {
    throw new Error(`Failed to fetch ./punct_cap_seg_en.onnx (HTTP ${modelResponse.status})`);
  }
  const buffer = await modelResponse.arrayBuffer();
  const session = await ort.InferenceSession.create(buffer, {
    executionProviders: ['wasm'],
  });

  pcsVocab = vocab;
  pcsVocabReverse = vocabReverse;
  pcsSession = session;
  console.log('English punctuator model loaded');
}

/**
 * Ensure the English punctuation + capitalization model is loaded (idempotent).
 *
 * Concurrent callers (e.g. the init-time preload and a first applyPunctuation
 * call) share one in-flight load instead of each downloading the model and
 * creating a separate ONNX session. On failure the error propagates and the
 * next call retries from scratch.
 *
 * @returns {Promise<void>}
 */
async function loadEnglishPunctuator() {
  if (pcsSession) return;
  if (!englishLoadPromise) {
    englishLoadPromise = createEnglishPunctuator();
  }
  const pending = englishLoadPromise;
  try {
    await pending;
  } catch (err) {
    // Only clear the slot if a newer load has not already replaced it.
    if (englishLoadPromise === pending) englishLoadPromise = null;
    throw err;
  }
}
/** In-flight multilingual model load shared by concurrent callers; null when idle. */
let multilingualLoadPromise = null;

/**
 * Load the transformers.js tokenizer and the quantized multilingual ONNX model.
 *
 * multilingualTokenizer and multilingualSession are assigned only after both
 * steps have succeeded, so a failed load leaves nothing half-initialised.
 *
 * @returns {Promise<void>}
 * @throws {Error} If the tokenizer cannot be loaded, the model download returns a
 *   non-OK status, or session creation fails.
 */
async function createMultilingualPunctuator() {
  console.log('Loading multilingual punctuator model...');

  // Tokenizer comes from transformers.js (loaded from the CDN on first use).
  const { AutoTokenizer } = await import('https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.4.2');
  const tokenizer = await AutoTokenizer.from_pretrained('oliverguhr/fullstop-punctuation-multilingual-base');

  const modelResponse = await cachedFetch('./punct_multilingual_q8.onnx');
  if (!modelResponse.ok) {
    throw new Error(`Failed to fetch ./punct_multilingual_q8.onnx (HTTP ${modelResponse.status})`);
  }
  const buffer = await modelResponse.arrayBuffer();
  const session = await ort.InferenceSession.create(buffer, {
    executionProviders: ['wasm'],
  });

  multilingualTokenizer = tokenizer;
  multilingualSession = session;
  console.log('Multilingual punctuator model loaded');
}

/**
 * Ensure the multilingual punctuation model is loaded (idempotent).
 *
 * Concurrent callers share one in-flight load. On failure the error propagates
 * and the next call retries from scratch.
 *
 * @returns {Promise<void>}
 */
async function loadMultilingualPunctuator() {
  if (multilingualSession) return;
  if (!multilingualLoadPromise) {
    multilingualLoadPromise = createMultilingualPunctuator();
  }
  const pending = multilingualLoadPromise;
  try {
    await pending;
  } catch (err) {
    // Only clear the slot if a newer load has not already replaced it.
    if (multilingualLoadPromise === pending) multilingualLoadPromise = null;
    throw err;
  }
}
/** Longest vocab piece, in UTF-16 code units, that the greedy tokenizer tries to match. */
const ENGLISH_MAX_PIECE_LEN = 20;

/**
 * Split `text` into English-model vocab pieces by greedy longest match.
 *
 * This approximates SentencePiece segmentation rather than running a true
 * Unigram (Viterbi) search. The text is lowercased, and each whitespace-separated
 * word is prefixed with the word marker '▁', so runs of spaces, tabs or newlines
 * collapse to a single word boundary. At every position the longest piece present
 * in `pcsVocab` is consumed. A character with no matching piece becomes one UNK token.
 *
 * @param {string} text - Raw input text.
 * @returns {{ids: number[], pieces: string[]}} Token ids (without BOS/EOS) and,
 *   for each id, the exact substring it covers. For UNK this is the original
 *   character, so out-of-vocabulary characters can be reproduced when decoding.
 */
function segmentEnglishPieces(text) {
  const normalized = text
    .toLowerCase()
    .split(/\s+/)
    .filter((word) => word.length > 0)
    .map((word) => `▁${word}`)
    .join('');

  const ids = [];
  const pieces = [];
  let pos = 0;
  while (pos < normalized.length) {
    let matched = null;
    for (let len = Math.min(normalized.length - pos, ENGLISH_MAX_PIECE_LEN); len > 0; len--) {
      const candidate = normalized.slice(pos, pos + len);
      // Own-property check: the vocab is a plain JSON object, so inherited keys
      // such as 'constructor' must not be mistaken for vocabulary entries.
      if (Object.prototype.hasOwnProperty.call(pcsVocab, candidate)) {
        matched = candidate;
        break;
      }
    }
    if (matched === null) {
      // Consume a whole code point so astral characters (e.g. emoji) give one UNK, not two.
      matched = String.fromCodePoint(normalized.codePointAt(pos));
      ids.push(PCS_CONFIG.unkId);
    } else {
      ids.push(pcsVocab[matched]);
    }
    pieces.push(matched);
    pos += matched.length;
  }
  return { ids, pieces };
}

/**
 * Tokenize `text` for the English model.
 *
 * @param {string} text - Raw input text.
 * @returns {number[]} Token ids wrapped as [BOS, ...pieces, EOS], ready for `input_ids`.
 */
function tokenizeEnglish(text) {
  return [PCS_CONFIG.bosId, ...segmentEnglishPieces(text).ids, PCS_CONFIG.eosId];
}

/**
 * Restore punctuation, capitalization and sentence boundaries in English text.
 *
 * The model returns one prediction per input position, with position 0 = BOS:
 *   - pre_preds:  1 => insert '¿' before the token
 *   - post_preds: 1 => acronym (a '.' after every character), >1 => PCS_CONFIG.postLabels[label]
 *                 after the token's last character
 *   - cap_preds:  [batch, seq, maxTokenLen] per-character upper-case flags
 *   - seg_preds:  truthy => a sentence ends after this token
 * Label tensors are read through Number() because int64 outputs arrive as
 * BigInt64Array, where `1n === 1` is false.
 *
 * NOTE(review): long inputs are not chunked here. If the model has a maximum
 * sequence length, presumably the session rejects over-long input and
 * applyPunctuation falls back to the original text — confirm.
 *
 * @param {string} text - Unpunctuated input text.
 * @returns {Promise<string>} Punctuated, capitalized text with sentences joined by single spaces.
 */
async function applyEnglishPunctuation(text) {
  await loadEnglishPunctuator();

  const { ids, pieces } = segmentEnglishPieces(text);
  const tokenIds = [PCS_CONFIG.bosId, ...ids, PCS_CONFIG.eosId];
  const inputTensor = new ort.Tensor(
    'int64',
    BigInt64Array.from(tokenIds, (id) => BigInt(id)),
    [1, tokenIds.length],
  );
  const outputs = await pcsSession.run({ input_ids: inputTensor });

  const prePreds = outputs.pre_preds.data;
  const postPreds = outputs.post_preds.data;
  const capPreds = outputs.cap_preds.data;
  const segPreds = outputs.seg_preds.data;
  // Characters per token covered by cap_preds; 16 for the shipped model.
  const maxTokenLen = outputs.cap_preds.dims?.[2] ?? 16;

  const sentences = [];
  let current = [];
  for (let i = 0; i < pieces.length; i++) {
    // Decode from the surface piece (not the reverse vocab), so UNK tokens
    // reproduce the original character instead of a placeholder like '<unk>'.
    const piece = pieces[i];
    const outputIdx = i + 1; // skip BOS
    const preLabel = Number(prePreds[outputIdx]);
    const postLabel = Number(postPreds[outputIdx]);
    const isWordStart = piece.startsWith('▁');

    if (isWordStart && current.length > 0) {
      current.push(' ');
    }

    const charStart = isWordStart ? 1 : 0;
    for (let j = charStart; j < piece.length; j++) {
      let char = piece[j];
      if (j === charStart && preLabel === 1) {
        current.push(PCS_CONFIG.preLabels[1]);
      }
      // Cap flags are indexed by character position within the piece, '▁' included.
      if (j < maxTokenLen && capPreds[outputIdx * maxTokenLen + j]) {
        char = char.toUpperCase();
      }
      current.push(char);

      if (postLabel === 1) {
        current.push('.'); // acronym: period after every character
      } else if (j === piece.length - 1 && postLabel > 1) {
        current.push(PCS_CONFIG.postLabels[postLabel]);
      }
    }

    if (segPreds[outputIdx]) {
      sentences.push(current.join(''));
      current = [];
    }
  }
  if (current.length > 0) {
    sentences.push(current.join(''));
  }
  return sentences.join(' ');
}
/** Words per inference window; chosen to stay below the 512-token truncation limit for ordinary text. */
const MULTILINGUAL_CHUNK_WORDS = 230;
/** Trailing words of each window that lack right context and are re-predicted by the next window. */
const MULTILINGUAL_OVERLAP_WORDS = 5;
/** Tokenizer special tokens that do not correspond to any input word. */
const MULTILINGUAL_SPECIAL_TOKENS = new Set(['<s>', '</s>', '<pad>']);

/**
 * Run the multilingual model on one window of words and return one label per word.
 *
 * Sub-word tokens are mapped back to words by the SentencePiece '▁' word-start
 * marker. Each word takes the label predicted for its LAST sub-word token, so
 * punctuation can only ever land at the end of a word.
 *
 * If the window still exceeds 512 tokens, the tokenizer truncates it; words past
 * the cut keep label 0 (no punctuation) but are not lost.
 *
 * @param {string[]} words - Non-empty, whitespace-free words.
 * @returns {Promise<number[]>} Label index (key of MULTILINGUAL_LABELS) per word.
 */
async function predictMultilingualWordLabels(words) {
  const encoded = await multilingualTokenizer(words.join(' '), {
    return_tensors: false,
    padding: false,
    truncation: true,
    max_length: 512,
  });
  const inputIds = encoded.input_ids;
  const attentionMask = encoded.attention_mask;
  const toInt64 = (values) => new ort.Tensor(
    'int64',
    BigInt64Array.from(values, (v) => BigInt(v)),
    [1, values.length],
  );

  const outputs = await multilingualSession.run({
    input_ids: toInt64(inputIds),
    attention_mask: toInt64(attentionMask),
  });
  const logits = outputs.logits.data;
  const numLabels = outputs.logits.dims?.[2] ?? 6;

  // Index of the highest logit for token position `tokenIdx` (first max wins on ties).
  const argmaxAt = (tokenIdx) => {
    const offset = tokenIdx * numLabels;
    let best = 0;
    for (let j = 1; j < numLabels; j++) {
      if (logits[offset + j] > logits[offset + best]) best = j;
    }
    return best;
  };

  const tokens = multilingualTokenizer.model.convert_ids_to_tokens(inputIds);
  const labels = new Array(words.length).fill(0);
  let wordIdx = -1;
  for (let i = 0; i < tokens.length; i++) {
    const token = tokens[i];
    if (MULTILINGUAL_SPECIAL_TOKENS.has(token)) continue;
    if (wordIdx < 0 || token.startsWith('▁')) wordIdx += 1;
    // Later sub-tokens of the same word overwrite earlier ones: the last one wins.
    if (wordIdx < words.length) labels[wordIdx] = argmaxAt(i);
  }
  return labels;
}

/**
 * Restore punctuation (no capitalization) using the multilingual model.
 *
 * The text is split into whitespace-separated words and processed in windows of
 * MULTILINGUAL_CHUNK_WORDS with a MULTILINGUAL_OVERLAP_WORDS overlap, so long
 * transcripts are punctuated in full instead of being truncated at 512 tokens.
 * The output is rebuilt from the original words (not token strings), so the
 * input text is preserved exactly apart from whitespace normalization and the
 * inserted punctuation.
 *
 * @param {string} text - Unpunctuated input text.
 * @returns {Promise<string>} Words joined by single spaces with predicted punctuation appended.
 */
async function applyMultilingualPunctuation(text) {
  await loadMultilingualPunctuator();

  const words = text.split(/\s+/).filter((word) => word.length > 0);
  const labels = new Array(words.length).fill(0);
  const step = MULTILINGUAL_CHUNK_WORDS - MULTILINGUAL_OVERLAP_WORDS;

  for (let start = 0; start < words.length; start += step) {
    const end = Math.min(start + MULTILINGUAL_CHUNK_WORDS, words.length);
    // Windows run sequentially: each inference call uses the shared session.
    const windowLabels = await predictMultilingualWordLabels(words.slice(start, end));
    const isLast = end === words.length;
    // Keep only the words that had right context; the next window re-predicts the rest.
    const keep = isLast ? windowLabels.length : step;
    for (let k = 0; k < keep; k++) {
      labels[start + k] = windowLabels[k];
    }
    if (isLast) break;
  }

  return words
    .map((word, i) => word + (MULTILINGUAL_LABELS[labels[i]] || ''))
    .join(' ');
}
/**
 * Main entry point: restore punctuation (and, for English, capitalization) in `text`.
 *
 * German, French, Italian, Dutch, Spanish and Portuguese go to the multilingual
 * punctuation-only model. Everything else, including `lang === null`, goes to the
 * English model. The language code is matched case-insensitively on its primary
 * subtag, so 'DE', 'de-DE' and 'pt_BR' are recognised as well as 'de' and 'pt'.
 *
 * This is best-effort: if a model fails to load or run, the error is logged and
 * the original text is returned unchanged.
 *
 * @param {string} text - Text to punctuate.
 * @param {string|null} [lang=null] - Language code such as 'de' or 'en-US'.
 * @returns {Promise<string>} Punctuated text; `text` itself if blank or on error.
 */
async function applyPunctuation(text, lang = null) {
  if (!text || text.trim().length === 0) return text;

  // Normalise e.g. 'DE' / 'de-DE' / 'pt_BR' to the bare codes in MULTILINGUAL_LANGS.
  const baseLang = typeof lang === 'string'
    ? lang.trim().toLowerCase().split(/[-_]/)[0]
    : null;
  const useMultilingual = baseLang !== null && MULTILINGUAL_LANGS.includes(baseLang);

  const [label, punctuate] = useMultilingual
    ? ['Multilingual', applyMultilingualPunctuation]
    : ['English', applyEnglishPunctuation];
  try {
    return await punctuate(text);
  } catch (error) {
    console.warn(`${label} punctuation failed, returning original:`, error);
    return text;
  }
}
/**
 * Preload the default (English) punctuator. Called during app init so that
 * later applyPunctuation calls find the English model already loaded.
 *
 * @returns {Promise<void>} Resolves once the English model is ready.
 */
async function loadPunctuator() {
  return loadEnglishPunctuator();
}
// Public API consumed by app.js.
Object.assign(window, {
  applyPunctuation,
  loadPunctuator,
  MULTILINGUAL_PUNCT_LANGS: MULTILINGUAL_LANGS,
});