
Commit 8231aad

standalone 2024 setup add LSTM lm pipeline

1 parent f87d62a

File tree: 7 files changed, +386 / -0 lines changed
@@ -0,0 +1,186 @@
from sisyphus import tk
from sisyphus.delayed_ops import DelayedFormat

from dataclasses import dataclass
import os
from typing import Any, Dict, Optional

from i6_core.text.label.subword_nmt.apply import ApplyBPEToTextJob
from i6_core.corpus.convert import CorpusToTxtJob
from i6_core.text.processing import ConcatenateJob
from i6_core.returnn.config import CodeWrapper

from i6_experiments.common.setups.returnn.datasets import MetaDataset, ControlDataset, Dataset
from i6_experiments.common.setups.returnn.datastreams.base import Datastream
from i6_experiments.common.setups.returnn.datastreams.vocabulary import BpeDatastream
from i6_experiments.common.helpers.text_labels.subword_nmt_bpe import get_returnn_subword_nmt

from i6_experiments.common.datasets.librispeech import get_bliss_corpus_dict
from i6_experiments.common.datasets.librispeech.vocab import get_subword_nmt_bpe_v2
from i6_experiments.common.datasets.librispeech.language_model import get_librispeech_normalized_lm_data


SOURCE_DATASTREAM_KEY = "data"
TARGET_DATASTREAM_KEY = "delayed"


@dataclass(frozen=True)
class TrainingDatasets:
    train: Dataset
    cv: Dataset
    devtrain: Dataset
    datastreams: Dict[str, Datastream]


class LmDataset(ControlDataset):

    def __init__(
        self,
        *,
        corpus_file: tk.Path,
        vocab_file: tk.Path,
        # super parameters
        partition_epoch: Optional[int] = None,
        segment_file: Optional[tk.Path] = None,
        seq_ordering: Optional[str] = None,
        random_subset: Optional[int] = None,
        additional_options: Optional[Dict] = None,
    ):
        super().__init__(
            partition_epoch=partition_epoch,
            segment_file=segment_file,
            seq_ordering=seq_ordering,
            random_subset=random_subset,
            additional_options=additional_options,
        )

        self.corpus_file = corpus_file
        self.vocab_file = vocab_file

    def as_returnn_opts(self) -> Dict[str, Any]:
        d = {
            "class": "LmDataset",
            "corpus_file": CodeWrapper(DelayedFormat('lambda: cf("{}")', self.corpus_file)),
            "orth_symbols_map_file": self.vocab_file,
            "orth_replace_map_file": "",
            "word_based": True,
            "seq_end_symbol": "</s>",
            "auto_replace_unknown_symbol": False,
            "unknown_symbol": "<unk>",
            "add_delayed_seq_data": True,
            "delayed_seq_data_start_symbol": "<s>",
        }
        sd = super().as_returnn_opts()
        assert all([k not in sd.keys() for k in d.keys()]), (
            "conflicting keys in %s and %s" % (str(list(sd.keys())), str(list(d.keys())))
        )
        d.update(sd)

        return d


@dataclass()
class LMDatasetSettings:
    train_partition_epoch: int
    train_seq_ordering: str


def get_subword_repo():
    """
    This is, for now, a rather ugly helper to get the same subword_nmt repo
    that get_subword_nmt_bpe_v2 uses.
    :return:
    """
    subword_nmt_repo = get_returnn_subword_nmt(
        commit_hash="5015a45e28a958f800ef1c50e7880c0c9ef414cf", output_prefix=""
    )
    # overwrite hash for future bugfixes; the logic is unlikely to ever change
    subword_nmt_repo.hash_overwrite = "I6_SUBWORD_NMT_V2"
    return subword_nmt_repo


def build_lm_training_datasets(prefix, librispeech_key, bpe_size, settings: LMDatasetSettings):

    # data_map = {SOURCE_DATASTREAM_KEY: ("lm_dataset", "data"), TARGET_DATASTREAM_KEY: ("lm_dataset", "delayed")}
    # def make_meta(dataset: LmDataset):
    #     return MetaDataset(
    #         data_map=data_map, datasets={"lm_dataset": dataset}, seq_order_control_dataset="lm_dataset"
    #     )

    bpe_settings = get_subword_nmt_bpe_v2(corpus_key=librispeech_key, bpe_size=bpe_size, unk_label='<unk>')
    ls_bliss_corpus_dict = get_bliss_corpus_dict()
    bpe_datastream = BpeDatastream(available_for_inference=False, bpe_settings=bpe_settings)

    #### Training Data ####

    lm_data = get_librispeech_normalized_lm_data()
    ls_train_bliss = ls_bliss_corpus_dict["train-other-960"]
    ls_train_text = CorpusToTxtJob(
        bliss_corpus=ls_train_bliss,
        gzip=True,
    ).out_txt
    full_train_text = ConcatenateJob(
        text_files=[lm_data, ls_train_text],
        zip_out=True,
    ).out
    lm_bpe_data_job = ApplyBPEToTextJob(
        text_file=full_train_text,
        bpe_codes=bpe_settings.bpe_codes,
        bpe_vocab=bpe_settings.bpe_count_vocab,
        gzip_output=True,
        subword_nmt_repo=get_subword_repo(),
        mini_task=False,  # this is a large file, so run on the cluster
    )
    lm_bpe_data_job.add_alias(os.path.join(prefix, "apply_bpe_to_train"))

    #### Dev Data ####

    dev_clean_text = CorpusToTxtJob(bliss_corpus=ls_bliss_corpus_dict["dev-clean"], gzip=True).out_txt
    dev_other_text = CorpusToTxtJob(bliss_corpus=ls_bliss_corpus_dict["dev-other"], gzip=True).out_txt
    cv_text = ConcatenateJob(
        text_files=[dev_clean_text, dev_other_text],
        zip_out=True,
    ).out
    cv_bpe_data_job = ApplyBPEToTextJob(
        text_file=cv_text,
        bpe_codes=bpe_settings.bpe_codes,
        bpe_vocab=bpe_settings.bpe_count_vocab,
        gzip_output=True,
        subword_nmt_repo=get_subword_repo(),
    )

    #### datasets ####

    lm_train_dataset = LmDataset(
        corpus_file=lm_bpe_data_job.out_bpe_text,
        vocab_file=bpe_settings.bpe_vocab,
        partition_epoch=settings.train_partition_epoch,
        segment_file=None,
        seq_ordering=settings.train_seq_ordering,
    )

    lm_cv_dataset = LmDataset(
        corpus_file=cv_bpe_data_job.out_bpe_text,
        vocab_file=bpe_settings.bpe_vocab,
        partition_epoch=1,
        segment_file=None,
        seq_ordering="sorted",
    )

    lm_devtrain_dataset = LmDataset(
        corpus_file=lm_bpe_data_job.out_bpe_text,
        vocab_file=bpe_settings.bpe_vocab,
        partition_epoch=1,
        segment_file=None,
        seq_ordering="sorted",
        random_subset=3000,
    )

    return TrainingDatasets(
        train=lm_train_dataset,
        cv=lm_cv_dataset,
        # devtrain=lm_devtrain_dataset,
        # TODO: Ultra hack for now
        devtrain=lm_cv_dataset,
        datastreams={"data": bpe_datastream, "delayed": bpe_datastream},
    )
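
A quick illustration of what the `LmDataset` wrapper above ultimately produces: a hedged sketch of the plain RETURNN dataset dictionary that `as_returnn_opts()` resolves to for the training set, once Sisyphus fills in the `tk.Path` objects and the `DelayedFormat` wrapper. The concrete paths are placeholders, and `cf` is assumed to be the cache-manager helper made available in the generated RETURNN config (cf. the `add_cache_manager` option in the experiment config later in this commit).

```python
# Hypothetical, resolved form of LmDataset.as_returnn_opts() for the training set.
# Paths are placeholders, not real Sisyphus job outputs.
train_lm_dataset = {
    "class": "LmDataset",
    # CodeWrapper + DelayedFormat emit this value as unquoted code, so the corpus
    # file is fetched lazily through the cache manager (cf) at runtime:
    "corpus_file": 'lambda: cf("/path/to/lm.train.bpe.txt.gz")',
    "orth_symbols_map_file": "/path/to/bpe.vocab",
    "orth_replace_map_file": "",
    "word_based": True,
    "seq_end_symbol": "</s>",
    "auto_replace_unknown_symbol": False,
    "unknown_symbol": "<unk>",
    "add_delayed_seq_data": True,  # adds the "delayed" stream (inputs shifted, starting with <s>)
    "delayed_seq_data_start_symbol": "<s>",
    # plus whatever ControlDataset.as_returnn_opts() contributes, e.g.:
    "partition_epoch": 4,
    "seq_ordering": "laplace:.100",
}
```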

users/rossenbach/experiments/librispeech/ctc_rnnt_standalone_2024/experiments/lm_bpe/__init__.py

Whitespace-only changes.
@@ -0,0 +1,77 @@
from sisyphus import tk

from dataclasses import asdict
from typing import cast

from i6_experiments.common.setups.returnn.datastreams.vocabulary import LabelDatastream

from ...data.bpe_lm import build_lm_training_datasets, LMDatasetSettings
from ...default_tools import RETURNN_EXE, MINI_RETURNN_ROOT
from ...pipeline import training


def bpe_kazuki_lstm():
    prefix_name = "experiments/librispeech/ctc_rnnt_standalone_2024/kazuki_lstm/"

    train_settings = LMDatasetSettings(
        train_partition_epoch=4,
        train_seq_ordering="laplace:.100",
    )

    # build the training datasets object containing train, cv, dev-train and the extern_data dict
    train_data_bpe10k = build_lm_training_datasets(
        prefix=prefix_name,
        librispeech_key="train-other-960",
        bpe_size=10000,
        settings=train_settings,
    )
    label_datastream_bpe10k = cast(LabelDatastream, train_data_bpe10k.datastreams["data"])
    vocab_size_without_blank = label_datastream_bpe10k.vocab_size

    default_returnn = {
        "returnn_exe": RETURNN_EXE,
        "returnn_root": MINI_RETURNN_ROOT,
    }

    from ...pytorch_networks.lm.lstm.kazuki_lstm_zijian_variant_v1_cfg import ModelConfig

    default_init_args = {
        "init_args_w": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
        "init_args_b": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
    }

    lstm_base_config = ModelConfig(
        vocab_dim=vocab_size_without_blank,
        embed_dim=512,
        hidden_dim=2048,
        n_lstm_layers=2,
        use_bottle_neck=False,
        dropout=0.2,
        init_args=default_init_args,
    )

    train_config_24gbgpu = {
        "optimizer": {"class": "SGD"},
        #############
        "batch_size": 1280,  # BPE tokens
        "accum_grad_multiple_step": 1,
        "learning_rate": 1.0,
        "decay": 0.8,
        "multi_num_epochs": train_settings.train_partition_epoch,
        "relative_error_threshold": 0,
        "multi_update_interval": 1,
        "error_measure": "dev_ce",
    }

    network_module = "lm.lstm.kazuki_lstm_zijian_variant_v1"
    train_args = {
        "config": train_config_24gbgpu,
        "network_module": network_module,
        "net_args": {"model_config_dict": asdict(lstm_base_config)},
        "debug": False,
        "add_cache_manager": True,
    }

    training_name = prefix_name + "/" + network_module + ".512dim_sub6_24gbgpu_50eps"
    train_job = training(training_name, train_data_bpe10k, train_args, num_epochs=30, **default_returnn)
    train_job.rqmt["gpu_mem"] = 24
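
The schedule above keys on `error_measure: "dev_ce"`, i.e. the dev cross entropy per BPE token, presumably consumed by the `training` pipeline helper for newbob-style learning-rate control. As a side note (not part of this commit), the perplexity usually quoted for such LSTM LMs is just the exponential of that value:

```python
import math


def ce_to_perplexity(ce_per_token: float) -> float:
    """Convert a per-token cross entropy (in nats) into perplexity."""
    return math.exp(ce_per_token)


# Example: a dev cross entropy of 3.4 nats per BPE token corresponds to a
# BPE-level perplexity of about exp(3.4) ~= 30.
print(ce_to_perplexity(3.4))
```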

users/rossenbach/experiments/librispeech/ctc_rnnt_standalone_2024/pytorch_networks/lm/__init__.py

Whitespace-only changes.

users/rossenbach/experiments/librispeech/ctc_rnnt_standalone_2024/pytorch_networks/lm/lstm/__init__.py

Whitespace-only changes.
@@ -0,0 +1,108 @@
import torch
from torch import nn

from .kazuki_lstm_zijian_variant_v1_cfg import ModelConfig


def mask_tensor(tensor: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
    """
    mask a tensor with a "positive" mask (boolean true means position is used)

    :param tensor: [B,T,....]
    :param seq_len: [B]
    :return: [B,T] as boolean
    """
    seq_len = seq_len.to(device=tensor.device)
    r = torch.arange(tensor.shape[1], device=tensor.device)  # [T]
    seq_mask = torch.less(r[None, :], seq_len[:, None])  # broadcast to [B,T]
    return seq_mask


class Model(nn.Module):
    """
    Simple LSTM LM with an embedding, an LSTM, and a final linear
    """

    def __init__(self, model_config_dict, **kwargs):
        super().__init__()
        self.cfg = ModelConfig(**model_config_dict)
        if self.cfg.dropout > 0:
            self.dropout = nn.Dropout(p=self.cfg.dropout)
        else:
            self.dropout = None
        self.use_bottle_neck = self.cfg.use_bottle_neck
        self.embed = nn.Embedding(self.cfg.vocab_dim, self.cfg.embed_dim)
        self.lstm = nn.LSTM(
            input_size=self.cfg.embed_dim,
            hidden_size=self.cfg.hidden_dim,
            num_layers=self.cfg.n_lstm_layers,
            bias=self.cfg.bias,
            batch_first=True,
            dropout=self.cfg.dropout,
            bidirectional=False,
        )
        if self.cfg.use_bottle_neck:
            self.bottle_neck = nn.Linear(self.cfg.hidden_dim, self.cfg.bottle_neck_dim, bias=True)
            self.final_linear = nn.Linear(self.cfg.bottle_neck_dim, self.cfg.vocab_dim, bias=True)
        else:
            self.final_linear = nn.Linear(self.cfg.hidden_dim, self.cfg.vocab_dim, bias=True)
        self._param_init(**self.cfg.init_args)

    def _param_init(self, init_args_w=None, init_args_b=None):
        if init_args_w is None:
            init_args_w = {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}}
        if init_args_b is None:
            init_args_b = {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}}

        for m in self.modules():
            for name, param in m.named_parameters():
                if "bias" in name:
                    if init_args_b["func"] == "normal":
                        init_func = nn.init.normal_
                    else:
                        raise NotImplementedError
                    hyp = init_args_b["arg"]
                else:
                    if init_args_w["func"] == "normal":
                        init_func = nn.init.normal_
                    else:
                        raise NotImplementedError
                    hyp = init_args_w["arg"]
                init_func(param, **hyp)

    def forward(self, x):
        """
        Return the logits for each sequence position of each batch entry.

        :param x: [B, S] integer token indices
        :return: [B, S, vocab_dim] logits
        """
        x = self.embed(x)
        if self.dropout:
            x = self.dropout(x)
        batch_size = x.shape[0]
        h0 = torch.zeros((self.cfg.n_lstm_layers, batch_size, self.cfg.hidden_dim), device=x.device).detach()
        c0 = torch.zeros_like(h0, device=x.device).detach()
        # This is a uni-directional LSTM, so sequence masking is not necessary
        x, _ = self.lstm(x, (h0, c0))
        if self.dropout:
            x = self.dropout(x)
        if self.use_bottle_neck:
            x = self.bottle_neck(x)
            if self.dropout:
                x = self.dropout(x)
        x = self.final_linear(x)
        return x


def train_step(*, model: Model, data, run_ctx, **kwargs):
    labels = data["data"]  # [B, S]
    labels_len = data["data:size1"]  # [B]
    delayed_labels = data["delayed"]  # [B, S], labels shifted, starting with <s>

    lm_logits = model(delayed_labels)  # [B, S, vocab_dim]

    ce_loss = torch.nn.functional.cross_entropy(lm_logits.transpose(1, 2), labels, reduction="none")
    seq_mask = mask_tensor(labels, labels_len)
    ce_loss = (ce_loss * seq_mask).sum()
    total_length = torch.sum(labels_len)

    run_ctx.mark_as_loss(name="ce", loss=ce_loss, inv_norm_factor=total_length)
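
For a quick sanity check outside of RETURNN, a minimal hedged sketch that instantiates the model above with a toy configuration and replays the masked cross-entropy computation from `train_step` on random data. The import path, the toy dimensions, and the shifted-input construction via `torch.roll` are assumptions made for the example only.

```python
import torch
from dataclasses import asdict

# Assuming the pytorch_networks directory is on sys.path, matching the
# "lm.lstm.kazuki_lstm_zijian_variant_v1" network_module naming above.
from lm.lstm.kazuki_lstm_zijian_variant_v1 import Model, mask_tensor
from lm.lstm.kazuki_lstm_zijian_variant_v1_cfg import ModelConfig

cfg = ModelConfig(
    vocab_dim=100,  # toy vocabulary; the real setup uses the BPE-10k vocab
    embed_dim=32,
    hidden_dim=64,
    n_lstm_layers=2,
    init_args={"init_args_w": None, "init_args_b": None},  # fall back to _param_init defaults
    dropout=0.1,
)
model = Model(model_config_dict=asdict(cfg))

batch_size, max_len = 3, 7
labels = torch.randint(0, cfg.vocab_dim, (batch_size, max_len))  # target tokens [B, S]
delayed = torch.roll(labels, shifts=1, dims=1)                   # stand-in for the <s>-shifted inputs
labels_len = torch.tensor([7, 5, 4])                             # true sequence lengths [B]

logits = model(delayed)                                          # [B, S, vocab_dim]
ce = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels, reduction="none")
ce = (ce * mask_tensor(labels, labels_len)).sum() / labels_len.sum()  # mean CE per real token
print(float(ce))  # roughly log(vocab_dim) ~= 4.6 before any training
```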
@@ -0,0 +1,15 @@
from dataclasses import dataclass

from i6_models.config import ModelConfiguration


@dataclass
class ModelConfig:
    vocab_dim: int
    embed_dim: int
    hidden_dim: int
    n_lstm_layers: int
    init_args: dict
    bias: bool = True
    use_bottle_neck: bool = False
    bottle_neck_dim: int = 512
    dropout: float = 0.0
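
The experiment config serializes this dataclass with `dataclasses.asdict` into `net_args`, and `Model.__init__` rebuilds it via `ModelConfig(**model_config_dict)`. A small sketch of that round trip; the vocabulary size and the import path are illustrative assumptions, in the real setup the size comes from the BPE datastream.

```python
from dataclasses import asdict

from lm.lstm.kazuki_lstm_zijian_variant_v1_cfg import ModelConfig  # assuming lm.lstm is importable

cfg = ModelConfig(
    vocab_dim=10025,  # illustrative; the real value is taken from the BPE-10k datastream
    embed_dim=512,
    hidden_dim=2048,
    n_lstm_layers=2,
    init_args={"init_args_w": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}},
               "init_args_b": {"func": "normal", "arg": {"mean": 0.0, "std": 0.1}}},
    dropout=0.2,
)

net_args = {"model_config_dict": asdict(cfg)}            # what the training job stores
restored = ModelConfig(**net_args["model_config_dict"])  # what Model.__init__ does
assert restored == cfg
```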
