Commit eb90b0a

update

robin-p-schmitt committed Mar 7, 2025
1 parent f4aa1f2 commit eb90b0a

Showing 6 changed files with 411 additions and 8 deletions.
users/schmitt/corpus/segment_ends.py (15 additions, 2 deletions)
@@ -1,14 +1,18 @@
import zipfile
import copy
from typing import Dict, Any

from sisyphus import Path, Job, Task

from i6_core.returnn.config import ReturnnConfig
from i6_core.lib import corpus


class AugmentCorpusSegmentEndsJob(Job):
-  def __init__(self, bliss_corpous: Path, oggzip_path: Path):
+  def __init__(self, bliss_corpous: Path, oggzip_path: Path, corpus_key: str = "dev-other"):
    self.bliss_corpus = bliss_corpous
    self.oggzip_path = oggzip_path
    self.corpus_key = corpus_key

    self.out_bliss_corpus = self.output_path("corpus.xml.gz")

@@ -31,6 +35,15 @@ def run(self):

    for segment in corpus_.segments():
      assert segment.start == 0.0
-     segment.end = durations[f"dev-other/{segment.name}/{segment.name}"]
+     segment.end = durations[f"{self.corpus_key}/{segment.name}/{segment.name}"]

    corpus_.dump(self.out_bliss_corpus.get_path())

  @classmethod
  def hash(cls, kwargs: Dict[str, Any]):
    d = copy.deepcopy(kwargs)

    # Keep existing job hashes stable: drop corpus_key from the hash when it is at its default.
    if d["corpus_key"] == "dev-other":
      d.pop("corpus_key")

    return super().hash(d)
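
The hash override follows the common Sisyphus idiom for extending a job without invalidating it: as long as the new corpus_key argument is left at its default, it is removed from the kwargs before hashing, so jobs created before this change keep their hash and are not recomputed. A minimal sketch of the effect (illustrative only; Sisyphus derives the real hash from the full job signature):

# Sketch: adding corpus_key with a default does not change existing hashes,
# because the default value is stripped before hashing.
old_kwargs = {"bliss_corpous": "corpus.xml.gz", "oggzip_path": "out.ogg.zip"}
new_kwargs = {**old_kwargs, "corpus_key": "dev-other"}

d = dict(new_kwargs)
if d["corpus_key"] == "dev-other":
  d.pop("corpus_key")

assert d == old_kwargs  # identical kwargs -> identical hash as before the change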
@@ -603,6 +603,12 @@ def analyze_encoder(
    param.trainable = True
  rf.set_requires_gradient(data)

  # Count and report the total number of encoder parameters.
  enc_numel = 0
  for name, param in model.encoder.named_parameters():
    enc_numel += param.raw_tensor.numel()

  print("NUMBER OF ENCODER PARAMETERS: ", enc_numel)

  ref_alignment_hdf = Path(config.typed_value("ref_alignment_hdf", str))
  ref_alignment_blank_idx = config.typed_value("ref_alignment_blank_idx", int)
  ref_alignment_vocab_path = Path(config.typed_value("ref_alignment_vocab_path", str))
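
For comparison, the same count in plain PyTorch (a sketch; model.encoder here stands for any torch.nn.Module, whereas the RETURNN-frontend parameters above expose their backing torch tensor via param.raw_tensor):

import torch

def count_parameters(module: torch.nn.Module) -> int:
  # Sum the element counts of all parameter tensors in the module.
  return sum(p.numel() for p in module.parameters())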
@@ -336,14 +336,66 @@ def get_rescore_config(self, opts: Dict):
    python_prolog = copy.deepcopy(self.python_prolog)
    python_epilog = copy.deepcopy(self.python_epilog)

    dataset_opts = opts.get("dataset_opts", {})
    config_dict["forward_data"] = {
      "class": "MetaDataset",
      "datasets": {
        "zip_dataset": {
          "audio": {
            "features": "raw",
            "peak_normalization": True,
            "pre_process": None,
            "preemphasis": None,
          },
          "class": "OggZipDataset",
          "epoch_wise_filter": None,
          "fixed_random_subset": None,
          "partition_epoch": 1,
          "path": [
            "/u/zeineldeen/setups/librispeech/2022-11-28--conformer-att/work/i6_core/returnn/oggzip/BlissToOggZipJob.NSdIHfk1iw2M/output/out.ogg.zip"
          ],
          "segment_file": None,
          "seq_ordering": "sorted_reverse",
          "targets": opts["vocab_opts"],
          "use_cache_manager": True,
        },
        "hyps": {
          "class": "TextDictDataset",
          "filename": opts["n_best_path"],
          "vocab": opts["vocab_opts"],
        },
      },
      "data_map": {
        "data": ("zip_dataset", "data"),
        "data_flat": ("hyps", "data_flat"),
        "data_seq_lens": ("hyps", "data_seq_lens"),
      },
      "seq_order_control_dataset": "hyps",
    }
    config_dict.update(dict(
      task="forward",
      search_output_layer="decision",
-     batching=opts.get("batching", "random")
+     batching=opts.get("batching", "random"),
+     target="data_flat"
    ))
-   config_dict.update(dict(forward_data=self.get_dataset(dataset_opts=dataset_opts, type_='search')))
-   extern_data_raw = self.get_extern_data_dict()

    from returnn.tensor import Dim, batch_dim
    _beam_dim = Dim(None, name="beam")
    _data_flat_spatial_dim = Dim(None, name="data_flat_spatial")
    extern_data_raw = {
      "data_flat": {
        "dims": [batch_dim, _data_flat_spatial_dim],
        "dtype": "int32",
        "vocab": opts["vocab_opts"],
      },
      "data_seq_lens": {"dims": [batch_dim, _beam_dim], "dtype": "int32"},
      "data": {
        "dim_tags": [
          batch_dim,
          Dim(None, name="time", kind=Dim.Types.Spatial),
          Dim(1, name="audio", kind=Dim.Types.Feature),
        ]
      },
    }
    extern_data_raw = instanciate_delayed(extern_data_raw)

config_dict["batch_size"] = opts.get("batch_size", 15_000) * self.batch_size_factor
@@ -27,7 +27,7 @@
from .model.custom_load_params import load_missing_params_aed
from .analysis.analysis import analyze_encoder
from .analysis.gmm_alignments import setup_gmm_alignment, LIBRISPEECH_GMM_WORD_ALIGNMENT
-# from .rescoring import rescore
+from .rescoring import aed_rescore
from .configs import (
  config_24gb_v1,
  config_24gb_v2,
@@ -350,3 +350,77 @@ def py():
  #   returnn_python_exe=RETURNN_EXE,
  #   alias=f"models/{model_name}",
  # )

  # rescore 1k AED model output with 10k AED model and 10k AED model output with 1k AED model
  from i6_experiments.users.zeyer.experiments.exp2024_04_23_baselines.interspeech_ctc_rescoring import \
    NBestListReduceNJob, ApplyVocabToNBestListJob
  from i6_core.returnn.search import SearchOutputRawReplaceJob
  configs = []
  configs.append(dict(
    model_opts=config_24gb_v1["model_opts"],
    checkpoint=PtCheckpoint(Path("/u/schmitt/experiments/segmental_models_2022_23_rf/work/i6_core/returnn/training/ReturnnTrainingJob.VNhMCmbnAUcd/output/models/epoch.2000.pt")),
    alias="1k-aed",
    vocab_opts=BPE1K_OPTS,
    rescore_n_best_path=Path("/work/asr3/zeyer/schmitt/sisyphus_work_dirs/segmental_models_2022_23_rf/i6_core/returnn/forward/ReturnnForwardJobV2.39mQsZKFtWYG/output/output.py.gz"),
    rescored_model_alias="10k-aed",
    n_best_per_model=16,
    own_n_best_path=Path("/work/asr3/zeyer/schmitt/sisyphus_work_dirs/segmental_models_2022_23_rf/i6_core/returnn/forward/ReturnnForwardJobV2.KHsw8m5r59Cc/output/output.py.gz"),
  ))

  # configs.append(dict(
  #   model_opts=config_24gb_v1["model_opts"],
  #   checkpoint=PtCheckpoint(Path("/u/schmitt/experiments/segmental_models_2022_23_rf/work/i6_core/returnn/training/ReturnnTrainingJob.xEQKl4JvwUe4/output/models/epoch.300.pt")),
  #   alias="10k-aed",
  #   vocab_opts=BPE10K_OPTS,
  #   rescore_n_best_path=Path("/work/asr3/zeyer/schmitt/sisyphus_work_dirs/segmental_models_2022_23_rf/i6_core/returnn/forward/ReturnnForwardJobV2.KHsw8m5r59Cc/output/output.py.gz"),
  #   rescored_model_alias="1k-aed",
  #   n_best_per_model=16,
  #   own_n_best_path=Path("/work/asr3/zeyer/schmitt/sisyphus_work_dirs/segmental_models_2022_23_rf/i6_core/returnn/forward/ReturnnForwardJobV2.39mQsZKFtWYG/output/output.py.gz"),
  # ))

  for config in configs:
    config_builder = AEDConfigBuilder(
      dataset=LIBRISPEECH_CORPUS,
      vocab_opts=config["vocab_opts"],
      model_def=aed_model_def,
      get_model_func=aed_get_model,
      batch_size_factor=1,
    )
    config_builder.config_dict.update(config["model_opts"])
    config_builder.config_dict["preload_from_files"] = dict(
      pretrained_params=dict(
        filename=config["checkpoint"],
        ignore_missing=True,
        custom_missing_load_func=load_missing_params_aed,
      ))
    n_best_per_model = config["n_best_per_model"]
    rescore_n_best_path = config["rescore_n_best_path"]
    rescored_model_alias = config["rescored_model_alias"]
    own_n_best_path = config["own_n_best_path"]

    rescore_n_best_path = SearchOutputRawReplaceJob(
      rescore_n_best_path, [("@@ ", "")], output_gzip=True).out_search_results
    rescore_n_best_path = NBestListReduceNJob(rescore_n_best_path, new_n=n_best_per_model).out_returnn_n_best

    own_n_best_path = SearchOutputRawReplaceJob(
      own_n_best_path, [("@@ ", "")], output_gzip=True).out_search_results
    own_n_best_path = NBestListReduceNJob(own_n_best_path, new_n=n_best_per_model).out_returnn_n_best

    # rescored_n_best_path = aed_rescore(
    #   config_builder=config_builder,
    #   corpus_key="dev-other",
    #   checkpoint=None,
    #   returnn_root=RETURNN_ROOT,
    #   returnn_python_exe=RETURNN_EXE,
    #   vocab_opts={
    #     "bpe_file": "/work/asr4/zeineldeen/setups-data/librispeech/2022-11-28--conformer-att/work/i6_core/text/label/subword_nmt/train/ReturnnTrainBpeJob.qhkNn2veTWkV/output/bpe.codes",
    #     "class": "BytePairEncoding",
    #     # "seq_postfix": [0],
    #     "unknown_label": None,
    #     "vocab_file": "/work/asr4/zeineldeen/setups-data/librispeech/2022-11-28--conformer-att/work/i6_core/text/label/subword_nmt/train/ReturnnTrainBpeJob.qhkNn2veTWkV/output/bpe.vocab",
    #     "bos_label": 0,
    #     "eos_label": 0,
    #   },
    #   n_best_path=own_n_best_path,  # rescore_n_best_path,
    #   alias=f"models/{config['alias']}/rescore-{rescored_model_alias}_{n_best_per_model}-best-per-model",
    # )
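
The two preprocessing jobs above operate on RETURNN's n-best search output, a Python dict mapping each sequence tag to a list of (score, hypothesis) pairs: SearchOutputRawReplaceJob strips the subword-nmt BPE joint marker "@@ " so the hypotheses become plain words, and NBestListReduceNJob truncates each list to new_n entries (assumed behavior, inferred from the job name and its new_n parameter; search output lists are typically already sorted best-first). A plain-Python sketch of both steps (illustrative; the jobs themselves read and write the gzipped output files):

n_best = {
  "dev-other/116-288045-0000/116-288045-0000": [
    (-1.23, "no@@ body was there"),
    (-1.57, "no@@ body was here"),
    (-2.01, "somebody was there"),
  ],
}

# Step 1: remove the BPE joint marker, like SearchOutputRawReplaceJob with ("@@ ", "").
n_best = {tag: [(score, text.replace("@@ ", "")) for score, text in hyps]
          for tag, hyps in n_best.items()}

# Step 2: keep only the first N hypotheses per sequence, like NBestListReduceNJob with new_n.
new_n = 2
n_best = {tag: hyps[:new_n] for tag, hyps in n_best.items()}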