[cli] support on-the-fly training by loading pt model as nn.Module
robin1001 committed Jan 7, 2025
1 parent 5e6d4e3 commit e3429bd
Showing 14 changed files with 38 additions and 4 deletions.
Empty file added wenet/LLM/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion wenet/__init__.py
@@ -1 +1 @@
-from wenet.cli.model import load_model  # noqa
+from wenet.cli.model import load_model, load_model_pt  # noqa
Empty file added wenet/bin/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions wenet/cli/model.py
@@ -17,11 +17,13 @@
+from types import SimpleNamespace
+
 import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
+import yaml
 
 from wenet.cli.hub import Hub
 from wenet.utils.ctc_utils import (force_align, gen_ctc_peak_time,
                                    gen_timestamps_from_peak)
 from wenet.utils.file_utils import read_symbol_table
+from wenet.utils.init_model import init_model
 from wenet.transformer.search import (attention_rescoring,
                                       ctc_prefix_beam_search, DecodeResult)
 from wenet.utils.context_graph import ContextGraph
@@ -174,3 +176,34 @@ def load_model(language: str = None,
     model.device = torch.device(device)
     model.model.to(device)
     return model
+
+
+# Load the PyTorch .pt model, which keeps all the model details that the
+# JIT export drops, so it can be used as a regular nn.Module for training.
+def load_model_pt(model_dir):
+    """ The following files are expected in `model_dir`:
+        * final.pt, required
+        * train.yaml, required
+        * units.txt, required
+        * global_cmvn, optional
+    """
+    # Check that all required files exist
+    required_files = ['train.yaml', 'final.pt', 'units.txt']
+    for file in required_files:
+        file_path = os.path.join(model_dir, file)
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(
+                f"Required file '{file}' not found in '{model_dir}'")
+    # Read the config and point it at the local tokenizer and CMVN files
+    config_file = os.path.join(model_dir, 'train.yaml')
+    with open(config_file, 'r') as fin:
+        configs = yaml.load(fin, Loader=yaml.FullLoader)
+    token_file = os.path.join(model_dir, 'units.txt')
+    configs['tokenizer_conf']['symbol_table_path'] = token_file
+    cmvn_file = os.path.join(model_dir, 'global_cmvn')
+    if os.path.exists(cmvn_file):
+        configs['cmvn_conf']['cmvn_file'] = cmvn_file
+    # Build the model and load the checkpoint; init_model reads options
+    # via attribute access (hasattr(args, ...)), so pass a namespace
+    # rather than a plain dict
+    pt_file = os.path.join(model_dir, 'final.pt')
+    args = SimpleNamespace(checkpoint=pt_file)
+    model, configs = init_model(args, configs)
+    return model
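
As a quick illustration of the on-the-fly training this enables, the sketch below loads the checkpoint and runs one optimizer step. It assumes the recent wenet ASRModel.forward(batch, device) interface, which returns a dict containing a 'loss' tensor; the directory name 'exp/my_model', the batch shapes, and the vocabulary size are illustrative, not part of this commit.

    import torch
    import wenet

    # 'exp/my_model' is a hypothetical directory containing final.pt,
    # train.yaml and units.txt, as described in the docstring above.
    model = wenet.load_model_pt('exp/my_model')
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    device = torch.device('cpu')
    # Dummy batch: 80-dim fbank features and token-id targets. In real
    # training these come from a wenet dataloader; the ids here assume a
    # vocabulary of at least 50 units in units.txt.
    batch = {
        'feats': torch.randn(2, 100, 80),
        'feats_lengths': torch.tensor([100, 90]),
        'target': torch.randint(1, 50, (2, 10)),
        'target_lengths': torch.tensor([10, 10]),
    }
    loss = model(batch, device)['loss']  # plain nn.Module forward call
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()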
Empty file added wenet/ctl_model/__init__.py
Empty file.
Empty file.
Empty file.
Empty file added wenet/finetune/__init__.py
Empty file.
Empty file added wenet/ssl/__init__.py
Empty file.
Empty file added wenet/ssl/bestrq/__init__.py
Empty file.
Empty file added wenet/ssl/w2vbert/__init__.py
Empty file.
Empty file added wenet/ssl/wav2vec2/__init__.py
Empty file.
Empty file.
7 changes: 4 additions & 3 deletions wenet/utils/init_model.py
@@ -213,12 +213,13 @@ def init_model(args, configs):
     if hasattr(args, 'lora_ckpt_path') and args.lora_ckpt_path:
         load_checkpoint(model, args.lora_ckpt_path)
 
-    print(configs)
     # Try to tie some weights
     if hasattr(model, 'tie_or_clone_weights'):
         if not hasattr(args, 'jit'):
-            args.jit = True  # i.e. export onnx/jit/ipex
-        model.tie_or_clone_weights(args.jit)
+            jit = True  # i.e. export onnx/jit/ipex
+        else:
+            jit = False
+        model.tie_or_clone_weights(jit)
 
     if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
         mark_only_lora_as_trainable(model, bias='lora_only')
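
For background on the jit flag: weight tying means the decoder's input embedding and output projection share one parameter tensor during training, while export paths (onnx/jit/ipex) get a cloned copy, since traced or scripted graphs do not always preserve parameter sharing. A rough sketch of that distinction (a common pattern, used e.g. by HuggingFace Transformers, and not wenet's exact implementation):

    import torch.nn as nn

    def tie_or_clone(proj: nn.Linear, emb: nn.Embedding, jit: bool) -> None:
        if jit:
            # Export mode: copy the weights so the exported graph holds
            # an independent tensor.
            proj.weight = nn.Parameter(emb.weight.clone())
        else:
            # Training mode: tie, so both layers update one shared tensor.
            proj.weight = emb.weight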
