from sisyphus import tk
from sisyphus.delayed_ops import DelayedFormat

from dataclasses import dataclass
import os
from typing import Any, Dict, Optional

from i6_core.text.label.subword_nmt.apply import ApplyBPEToTextJob
from i6_core.corpus.convert import CorpusToTxtJob
from i6_core.text.processing import ConcatenateJob
from i6_core.returnn.config import CodeWrapper

from i6_experiments.common.setups.returnn.datasets import MetaDataset, ControlDataset, Dataset
from i6_experiments.common.setups.returnn.datastreams.base import Datastream
from i6_experiments.common.setups.returnn.datastreams.vocabulary import BpeDatastream
from i6_experiments.common.helpers.text_labels.subword_nmt_bpe import get_returnn_subword_nmt

from i6_experiments.common.datasets.librispeech import get_bliss_corpus_dict
from i6_experiments.common.datasets.librispeech.vocab import get_subword_nmt_bpe_v2
from i6_experiments.common.datasets.librispeech.language_model import get_librispeech_normalized_lm_data


SOURCE_DATASTREAM_KEY = "data"
TARGET_DATASTREAM_KEY = "delayed"


@dataclass(frozen=True)
class TrainingDatasets:
    train: Dataset
    cv: Dataset
    devtrain: Dataset
    datastreams: Dict[str, Datastream]


class LmDataset(ControlDataset):
    """
    Wrapper around the RETURNN `LmDataset`, reading a (gzipped) text corpus
    with a given vocabulary file.
    """

    def __init__(
        self,
        *,
        corpus_file: tk.Path,
        vocab_file: tk.Path,
        # super parameters
        partition_epoch: Optional[int] = None,
        segment_file: Optional[tk.Path] = None,
        seq_ordering: Optional[str] = None,
        random_subset: Optional[int] = None,
        additional_options: Optional[Dict] = None,
    ):
        super().__init__(
            partition_epoch=partition_epoch,
            segment_file=segment_file,
            seq_ordering=seq_ordering,
            random_subset=random_subset,
            additional_options=additional_options,
        )

        self.corpus_file = corpus_file
        self.vocab_file = vocab_file

    def as_returnn_opts(self) -> Dict[str, Any]:
        d = {
            "class": "LmDataset",
            # `cf` (a cache-manager file helper) is assumed to be defined in the
            # generated RETURNN config; the lambda defers the call to runtime
            "corpus_file": CodeWrapper(DelayedFormat('lambda: cf("{}")', self.corpus_file)),
            "orth_symbols_map_file": self.vocab_file,
            "orth_replace_map_file": "",
            "word_based": True,
            "seq_end_symbol": "</s>",
            "auto_replace_unknown_symbol": False,
            "unknown_symbol": "<unk>",
            "add_delayed_seq_data": True,
            "delayed_seq_data_start_symbol": "<s>",
        }
        sd = super().as_returnn_opts()
        assert all(k not in sd for k in d), "conflicting keys in %s and %s" % (
            str(list(sd.keys())),
            str(list(d.keys())),
        )
        d.update(sd)

        return d
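

# A minimal sketch (illustration only; the paths and option values below are
# placeholders, not produced by this file) of what LmDataset.as_returnn_opts()
# renders to in the generated RETURNN config, assuming `cf` is defined there:
#
#   train = {
#       "class": "LmDataset",
#       "corpus_file": lambda: cf("/path/to/train.bpe.txt.gz"),
#       "orth_symbols_map_file": "/path/to/bpe.vocab",
#       "orth_replace_map_file": "",
#       "word_based": True,
#       "seq_end_symbol": "</s>",
#       "auto_replace_unknown_symbol": False,
#       "unknown_symbol": "<unk>",
#       "add_delayed_seq_data": True,
#       "delayed_seq_data_start_symbol": "<s>",
#       "partition_epoch": 4,              # from the super-class options, if set
#       "seq_ordering": "laplace:.1000",   # likewise
#   }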


@dataclass()
class LMDatasetSettings:
    train_partition_epoch: int
    train_seq_ordering: str


def get_subword_repo() -> tk.Path:
    """
    Helper to get the same subword_nmt repository as used by
    get_subword_nmt_bpe_v2.

    :return: path to the subword-nmt repository clone
    """
    subword_nmt_repo = get_returnn_subword_nmt(
        commit_hash="5015a45e28a958f800ef1c50e7880c0c9ef414cf", output_prefix=""
    )
    # overwrite the hash for future bugfixes; it is unlikely the logic will ever change
    subword_nmt_repo.hash_overwrite = "I6_SUBWORD_NMT_V2"
    return subword_nmt_repo


def build_lm_training_datasets(
    prefix: str, librispeech_key: str, bpe_size: int, settings: LMDatasetSettings
) -> TrainingDatasets:
    """
    Build train/cv/devtrain LM datasets for LibriSpeech with BPE labels.

    :param prefix: alias prefix for the jobs created here
    :param librispeech_key: LibriSpeech corpus key the BPE is estimated on, e.g. "train-other-960"
    :param bpe_size: number of BPE merge operations
    :param settings: partition-epoch and sequence-ordering settings for the training set
    """
    # data_map = {SOURCE_DATASTREAM_KEY: ("lm_dataset", "data"), TARGET_DATASTREAM_KEY: ("lm_dataset", "delayed")}
    # def make_meta(dataset: LmDataset):
    #     return MetaDataset(
    #         data_map=data_map, datasets={"lm_dataset": dataset}, seq_order_control_dataset="lm_dataset"
    #     )

    bpe_settings = get_subword_nmt_bpe_v2(corpus_key=librispeech_key, bpe_size=bpe_size, unk_label="<unk>")
    ls_bliss_corpus_dict = get_bliss_corpus_dict()
    bpe_datastream = BpeDatastream(available_for_inference=False, bpe_settings=bpe_settings)

    #### Training Data ####

    lm_data = get_librispeech_normalized_lm_data()
    ls_train_bliss = ls_bliss_corpus_dict["train-other-960"]
    ls_train_text = CorpusToTxtJob(
        bliss_corpus=ls_train_bliss,
        gzip=True,
    ).out_txt
    full_train_text = ConcatenateJob(
        text_files=[lm_data, ls_train_text],
        zip_out=True,
    ).out
    lm_bpe_data_job = ApplyBPEToTextJob(
        text_file=full_train_text,
        bpe_codes=bpe_settings.bpe_codes,
        bpe_vocab=bpe_settings.bpe_count_vocab,
        gzip_output=True,
        subword_nmt_repo=get_subword_repo(),
        mini_task=False,  # this is a large file, so run on the cluster
    )
    lm_bpe_data_job.add_alias(os.path.join(prefix, "apply_bpe_to_train"))

    #### Dev Data ####

    dev_clean_text = CorpusToTxtJob(bliss_corpus=ls_bliss_corpus_dict["dev-clean"], gzip=True).out_txt
    dev_other_text = CorpusToTxtJob(bliss_corpus=ls_bliss_corpus_dict["dev-other"], gzip=True).out_txt
    cv_text = ConcatenateJob(
        text_files=[dev_clean_text, dev_other_text],
        zip_out=True,
    ).out
    cv_bpe_data_job = ApplyBPEToTextJob(
        text_file=cv_text,
        bpe_codes=bpe_settings.bpe_codes,
        bpe_vocab=bpe_settings.bpe_count_vocab,
        gzip_output=True,
        subword_nmt_repo=get_subword_repo(),
    )

    #### datasets ####
    lm_train_dataset = LmDataset(
        corpus_file=lm_bpe_data_job.out_bpe_text,
        vocab_file=bpe_settings.bpe_vocab,
        partition_epoch=settings.train_partition_epoch,
        segment_file=None,
        seq_ordering=settings.train_seq_ordering,
    )

    lm_cv_dataset = LmDataset(
        corpus_file=cv_bpe_data_job.out_bpe_text,
        vocab_file=bpe_settings.bpe_vocab,
        partition_epoch=1,
        segment_file=None,
        seq_ordering="sorted",
    )

    lm_devtrain_dataset = LmDataset(
        corpus_file=lm_bpe_data_job.out_bpe_text,
        vocab_file=bpe_settings.bpe_vocab,
        partition_epoch=1,
        segment_file=None,
        seq_ordering="sorted",
        random_subset=3000,
    )

    return TrainingDatasets(
        train=lm_train_dataset,
        cv=lm_cv_dataset,
        # devtrain=lm_devtrain_dataset,
        # TODO: hack for now, the cv set is reused as devtrain until lm_devtrain_dataset is used again
        devtrain=lm_cv_dataset,
        datastreams={SOURCE_DATASTREAM_KEY: bpe_datastream, TARGET_DATASTREAM_KEY: bpe_datastream},
    )
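
# Example usage (a sketch, not from an actual setup; the prefix, BPE size and
# settings values below are placeholders):
#
#   lm_datasets = build_lm_training_datasets(
#       prefix="datasets/lm_bpe_2k",
#       librispeech_key="train-other-960",
#       bpe_size=2000,
#       settings=LMDatasetSettings(
#           train_partition_epoch=4,
#           train_seq_ordering="laplace:.1000",
#       ),
#   )
#   # lm_datasets.train / .cv / .devtrain then provide the dataset dicts for a
#   # RETURNN training config via their as_returnn_opts() methods.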