-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathwhisper_wrapper.py
28 lines (23 loc) · 1.1 KB
/
whisper_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from transformers import pipeline
from ModelInterfaces import IASRModel
from typing import Union
import numpy as np
class WhisperASRModel(IASRModel):
    """ASR backend built on Hugging Face's Whisper pipeline.

    Uses ``return_timestamps="word"`` so the pipeline result contains a
    ``"chunks"`` list with per-word text and (start, end) timestamps in
    seconds, which are converted to sample offsets here.
    """

    def __init__(self, model_name="openai/whisper-base"):
        # return_timestamps="word" makes the pipeline emit word-level "chunks".
        self.asr = pipeline("automatic-speech-recognition", model=model_name, return_timestamps="word")
        self._transcript = ""
        self._word_locations = []
        self.sample_rate = 16000  # Whisper models expect/emit 16 kHz audio.

    def processAudio(self, audio: Union[np.ndarray, torch.Tensor, str]):
        """Transcribe `audio` and cache the transcript and word locations.

        `audio` may be a path to an audio file, a 1-D array of samples,
        or a batched (1, N) array/tensor — only the first batch item is
        used. (The original code did ``audio[0]`` unconditionally, which
        broke file-path strings and 1-D sample arrays.)
        """
        if isinstance(audio, torch.Tensor):
            audio = audio.detach().cpu().numpy()
        if isinstance(audio, np.ndarray) and audio.ndim > 1:
            # Strip the leading batch dimension, matching the original
            # behavior for (1, N) inputs.
            audio = audio[0]
        result = self.asr(audio)
        self._transcript = result["text"]

        word_locations = []
        for word_info in result["chunks"]:
            start, end = word_info["timestamp"]
            # Whisper can return None as the end timestamp of the final
            # word; fall back to the start so the arithmetic below
            # doesn't raise TypeError.
            if end is None:
                end = start
            word_locations.append({
                "word": word_info["text"],
                "start_ts": start * self.sample_rate,
                "end_ts": end * self.sample_rate,
            })
        self._word_locations = word_locations

    def getTranscript(self) -> str:
        """Return the transcript from the last processAudio call ('' if none)."""
        return self._transcript

    def getWordLocations(self) -> list:
        """Return word timing dicts with keys 'word', 'start_ts', 'end_ts' (sample offsets)."""
        return self._word_locations