Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace multiprocessing.dummy.Pool() with concurrent.futures.ThreadPoolExecutor() so whisper_s2t instance can run separately with multiprocessing.Process() #74

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Next Next commit
multiprocessing.dummy.Pool() replaced with safer concurrent.futures.T…
…hreadPoolExecutor().

I tried loading whisper model and running transcription in a separate multiprocessing.Process(), but failed.
The process just froze unresponsive when loading audio to memory with multiprocessing.dummy.Pool(). Replaced with
concurrent.futures.ThreadPoolExecutor() and it works now!
  • Loading branch information
kirillsaidov committed Oct 17, 2024
commit 313b7758a346d34a585f34e798849c5e0b130d06
36 changes: 32 additions & 4 deletions whisper_s2t/audio.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F

from multiprocessing.dummy import Pool
import concurrent
# from multiprocessing.dummy import Pool

from . import BASE_PATH
from .configs import *
@@ -66,9 +67,36 @@ def load_audio(input_file, sr=16000, return_duration=False):
return audio_signal


THREAD_POOL_AUDIO_LOADER = Pool(2)
def audio_batch_generator(audio_files):
return THREAD_POOL_AUDIO_LOADER.imap(load_audio, audio_files)
# THREAD_POOL_AUDIO_LOADER = Pool(2)
# def audio_batch_generator(audio_files):
# return THREAD_POOL_AUDIO_LOADER.imap(load_audio, audio_files)


def audio_batch_generator(audio_files: list, parallel: bool = True, max_workers: int = 2):
"""
Generate batches of loaded audio files, with option for parallel or sequential loading.

Args:
audio_files (list): list of paths to audio files
parallel (bool, optional, default=True): tries parallel loading if True else uses sequential loading
max_workers (int, optional, default=2): maximum number of parallel workers (only used if parallel=True)

Returns:
Iterator of loaded audio data
"""
# try parallel loading with ThreadPoolExecutor (safer than multiprocessing.dummy.Pool)
if parallel:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
try:
yield from executor.map(load_audio, audio_files)
return # if parallel loading succeeded, we are done
except Exception as e:
print(f'Parallel audio loading failed: {str(e)}. Fall back to sequential loading...')
parallel = False

# sequential loading (fallback)
for audio_file in audio_files:
yield load_audio(audio_file)


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):