diff --git a/wenet/cli/hub.py b/wenet/cli/hub.py
index b8ca91ad5..171e334fa 100644
--- a/wenet/cli/hub.py
+++ b/wenet/cli/hub.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import os
-import requests
 import sys
 import tarfile
 from pathlib import Path
 from urllib.request import urlretrieve
 
+import requests
 import tqdm
 
 
@@ -77,7 +77,9 @@ class Hub(object):
         # gigaspeech
         "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
         # paraformer
-        "paraformer": "paraformer.tar.gz"
+        "paraformer": "paraformer.tar.gz",
+        # punc
+        "punc": "punc.tar.gz"
     }
 
     def __init__(self) -> None:
diff --git a/wenet/cli/punc_model.py b/wenet/cli/punc_model.py
new file mode 100644
index 000000000..cab6ca73a
--- /dev/null
+++ b/wenet/cli/punc_model.py
@@ -0,0 +1,115 @@
+import os
+from typing import List
+
+import jieba
+import torch
+from wenet.cli.hub import Hub
+from wenet.paraformer.search import _isAllAlpha
+from wenet.text.char_tokenizer import CharTokenizer
+
+
+class PuncModel:
+
+    def __init__(self, model_dir: str) -> None:
+        self.model_dir = model_dir
+        model_path = os.path.join(model_dir, 'final.zip')
+        units_path = os.path.join(model_dir, 'units.txt')
+
+        self.model = torch.jit.load(model_path)
+        self.tokenizer = CharTokenizer(units_path)
+        self.device = torch.device("cpu")
+        self.use_jieba = False
+
+        self.punc_table = ['', '', ',', '。', '?', '、']
+
+    def split_words(self, text: str):
+        if not self.use_jieba:
+            self.use_jieba = True
+            import logging
+
+            # Disable jieba's logger
+            logging.getLogger('jieba').disabled = True
+            jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))
+
+        result_list = []
+        tokens = text.split()
+        current_language = None
+        buffer = []
+
+        for token in tokens:
+            is_english = token.isascii()
+            if is_english:
+                language = "English"
+            else:
+                language = "Chinese"
+
+            if current_language and language != current_language:
+                if current_language == "Chinese":
+                    result_list.extend(jieba.cut(''.join(buffer), HMM=False))
+                else:
+                    result_list.extend(buffer)
+                buffer = []
+
+            buffer.append(token)
+            current_language = language
+
+        if buffer:
+            if current_language == "Chinese":
+                result_list.extend(jieba.cut(''.join(buffer), HMM=False))
+            else:
+                result_list.extend(buffer)
+
+        return result_list
+
+    def add_punc_batch(self, texts: List[str]):
+        batch_text_words = []
+        batch_text_ids = []
+        batch_text_lens = []
+
+        for text in texts:
+            words = self.split_words(text)
+            ids = self.tokenizer.tokens2ids(words)
+            batch_text_words.append(words)
+            batch_text_ids.append(ids)
+            batch_text_lens.append(len(ids))
+
+        texts_tensor = torch.tensor(batch_text_ids,
+                                    device=self.device,
+                                    dtype=torch.int64)
+        texts_lens_tensor = torch.tensor(batch_text_lens,
+                                         device=self.device,
+                                         dtype=torch.int64)
+
+        log_probs, _ = self.model(texts_tensor, texts_lens_tensor)
+        result = []
+        outs = log_probs.argmax(-1).cpu().numpy()
+        for i, out in enumerate(outs):
+            punc_id = out[:batch_text_lens[i]]
+            sentence = ''
+            for j, word in enumerate(batch_text_words[i]):
+                if _isAllAlpha(word):
+                    word = '▁' + word
+                word += self.punc_table[punc_id[j]]
+                sentence += word
+            result.append(sentence.replace('▁', ' '))
+        return result
+
+    def __call__(self, text: str):
+        if text != '':
+            r = self.add_punc_batch([text])[0]
+            return r
+        return ''
+
+
+def load_model(model_dir: str = None,
+               gpu: int = -1,
+               device: str = "cpu") -> PuncModel:
+    if model_dir is None:
+        model_dir = Hub.get_model_by_lang('punc')
+    if gpu != -1:
+        # remain the original usage of gpu
+        device = "cuda"
+    punc = PuncModel(model_dir)
+    punc.device = torch.device(device)
+    punc.model.to(device)
+    return punc
diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py
index 28bf27919..8d65447c0 100644
--- a/wenet/cli/transcribe.py
+++ b/wenet/cli/transcribe.py
@@ -14,8 +14,9 @@
 
 import argparse
 
-from wenet.cli.paraformer_model import load_model as load_paraformer
 from wenet.cli.model import load_model
+from wenet.cli.paraformer_model import load_model as load_paraformer
+from wenet.cli.punc_model import load_model as load_punc_model
 
 
 def get_args():
@@ -64,6 +65,13 @@ def get_args():
                         type=float,
                         default=6.0,
                         help='context score')
+    parser.add_argument('--punc', action='store_true', help='add punctuation')
+
+    parser.add_argument('-pm',
+                        '--punc_model_dir',
+                        default=None,
+                        help='specify your own punc model dir')
+
     args = parser.parse_args()
     return args
 
@@ -76,10 +84,17 @@ def main():
     else:
         model = load_model(args.language, args.model_dir, args.gpu, args.beam,
                            args.context_path, args.context_score, args.device)
+    punc_model = None
+    if args.punc:
+        punc_model = load_punc_model(args.punc_model_dir, args.gpu,
+                                     args.device)
     if args.align:
         result = model.align(args.audio_file, args.label)
     else:
         result = model.transcribe(args.audio_file, args.show_tokens_info)
+    if args.punc:
+        assert punc_model is not None
+        result['text_with_punc'] = punc_model(result['text'])
 
     print(result)
diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py
index d17280d8a..4c4a37390 100644
--- a/wenet/paraformer/layers.py
+++ b/wenet/paraformer/layers.py
@@ -3,18 +3,17 @@
 import math
 from typing import Optional, Tuple
 
-import torch
+import torch
 import torch.utils.checkpoint as ckpt
-
 from wenet.paraformer.attention import (DummyMultiHeadSANM,
                                         MultiHeadAttentionCross,
                                         MultiHeadedAttentionSANM)
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
 from wenet.paraformer.subsampling import IdentitySubsampling
-from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.decoder_layer import DecoderLayer
+from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.mask import make_non_pad_mask
 
@@ -190,7 +189,7 @@ def __init__(
         num_blocks: int = 6,
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0,
+        attention_dropout_rate: float = 0.0,
         input_layer: str = "conv2d",
         pos_enc_layer_type: str = "abs_pos",
         normalize_before: bool = True,
@@ -389,8 +388,8 @@ def __init__(
         num_blocks: int = 6,
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0,
-        src_attention_dropout_rate: float = 0,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,
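
Usage sketch (not part of the patch): a minimal example of driving the new punctuation model from Python, assuming the package is installed and the "punc" asset registered in Hub above is downloadable; the sample sentence is made up for illustration.

    from wenet.cli.punc_model import load_model as load_punc_model

    # With no model_dir, load_model() fetches the "punc" asset via
    # Hub.get_model_by_lang('punc') and loads the TorchScript model on CPU;
    # pass gpu=0 (or device="cuda") to move it to GPU instead.
    punc = load_punc_model()

    # __call__ wraps add_punc_batch([...]) and returns the input text with the
    # predicted marks from punc_table attached to each word.
    print(punc('其实 深度学习 里面 有 很多 训练 技巧'))

On the command line the same model sits behind the new --punc flag of wenet/cli/transcribe.py (optionally with -pm/--punc_model_dir pointing at a local model directory), which stores the punctuated transcript under the 'text_with_punc' key of the printed result.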