Skip to content

Commit

Permalink
Add weighting for top-k dumping
Browse files Browse the repository at this point in the history
  • Loading branch information
mmueller00 committed Mar 7, 2025
1 parent 4536e91 commit 05f9013
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 116 deletions.
50 changes: 42 additions & 8 deletions users/mueller/datasets/librispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from returnn.tensor import Tensor, Dim
from i6_experiments.users.zeyer.collect_model_dataset_stats import StatisticsOutput

from .utils import CorpusReplaceOrthFromPyDictJob, get_ogg_zip_dict_pseudo_labels, MetaDataset
from .utils import CorpusReplaceOrthFromPyDictJob, get_ogg_zip_dict_pseudo_labels, MetaDataset, GetScoresDummy

_alias_prefix = "datasets/LibriSpeech/"

Expand All @@ -48,7 +48,7 @@ def _get_librispeech_ogg_zip_dict() -> Dict[str, tk.Path]:


@cache
def _get_bliss_corpus_dict(pseudo_labels_path: tk.Path, part: str) -> Dict[str, tk.Path]:
def _get_bliss_corpus_dict(pseudo_labels_path: tk.Path, part: str, return_scores: bool = False) -> Dict[str, tk.Path]:
# Get Bliss corpus. Same audio format as in ogg_zip, so already there anyway due to how we created the ogg_zip.
# WARNING: Do not use these directly... It will keep another ogg copy of the audio...
# However, these are used later in the scoring, so when changing them, make sure it's optional,
Expand All @@ -58,21 +58,21 @@ def _get_bliss_corpus_dict(pseudo_labels_path: tk.Path, part: str) -> Dict[str,
bliss_corpus_dict = librispeech.get_bliss_corpus_dict(audio_format="ogg")
# load pseudo labels and replace here
bliss_corpus = bliss_corpus_dict[part]
replace_job = CorpusReplaceOrthFromPyDictJob(bliss_corpus, pseudo_labels_path)
replace_job = CorpusReplaceOrthFromPyDictJob(bliss_corpus, pseudo_labels_path, return_scores=return_scores)
replace_job.add_alias(os.path.join("datasets", "LibriSpeech-PseudoLabels", "%s_replace_orth" % part.replace('-', '_')))
bliss_corpus = replace_job.out_corpus
return {part: bliss_corpus}
return {part: bliss_corpus}, replace_job.scores_file if return_scores else None
else:
return librispeech.get_bliss_corpus_dict(audio_format="ogg")


@cache
def _get_librispeech_ogg_zip_dict_pseudo_labels(pseudo_labels_path: tk.Path, part: str) -> Dict[str, tk.Path]:
def _get_librispeech_ogg_zip_dict_pseudo_labels(pseudo_labels_path: tk.Path, part: str, return_scores: bool) -> Dict[str, tk.Path]:
# print("Convert pseudo labels to ogg")

bliss_corpus_dict = _get_bliss_corpus_dict(pseudo_labels_path, part)
bliss_corpus_dict, scores_dict = _get_bliss_corpus_dict(pseudo_labels_path, part, return_scores)

return get_ogg_zip_dict_pseudo_labels(bliss_corpus_dict)
return get_ogg_zip_dict_pseudo_labels(bliss_corpus_dict), scores_dict


@cache
Expand Down Expand Up @@ -314,6 +314,7 @@ def __init__(
train_ds_key: Optional[str] = None,
pseudo_label_path: tk.Path = None,
keep_small_labels: bool = False,
pseudo_nbest: Optional[int] = None,
):
"""
:param with_eos_postfix: For RETURNN train/dev/eval datasets, mostly relevant for training.
Expand Down Expand Up @@ -351,6 +352,7 @@ def __init__(
self.train_epoch_wise_filter = train_epoch_wise_filter
self.eval_subset = eval_subset
self.train_subset = train_subset
self.pseudo_nbest = pseudo_nbest

self._time_dim = None
self._feature_dim = None
Expand Down Expand Up @@ -387,6 +389,8 @@ def _sis_hash(self) -> bytes:
state.pop("train_subset")
if not self.keep_small_labels:
state.pop("keep_small_labels")
if self.pseudo_nbest is None or self.pseudo_nbest == 1:
state.pop("pseudo_nbest")
state = {k: v for k, v in state.items() if not k.startswith("_")}
byte_list = [b"LibrispeechOggZip", sis_hash_helper(state)]

Expand Down Expand Up @@ -419,6 +423,11 @@ def get_extern_data(self) -> Dict[str, Dict[str, Any]]:
"sparse_dim": self._classes_dim,
"vocab": self.vocab.get_opts(),
}

if self.pseudo_nbest is not None and self.pseudo_nbest > 1:
opts["weights"] = {
"dim_tags": [batch_dim, Dim(self.pseudo_nbest, name="pseudo_nbest")]
}

return opts

Expand Down Expand Up @@ -510,13 +519,20 @@ def get_dataset(self, key: str, *, training: bool = False, subset: Optional[int]
d["fixed_random_subset"] = subset # faster

# Combine pseudo labels into MetaDataset
return_scores = self.pseudo_nbest is not None and self.pseudo_nbest > 1
if training and self.pseudo_label_path:
files_new = []
score_files = []
for part in parts:
if part == "train-clean-100" and self.keep_small_labels:
files_new += [_get_librispeech_ogg_zip_dict()[part]]
if return_scores:
score_files += [GetScoresDummy(_get_bliss_corpus_dict(None, None)[part], self.pseudo_nbest).scores_file]
else:
files_new += [_get_librispeech_ogg_zip_dict_pseudo_labels(self.pseudo_label_path, part)[part]]
ogg_files, scores = _get_librispeech_ogg_zip_dict_pseudo_labels(self.pseudo_label_path, part, return_scores)
files_new += [ogg_files[part]]
if return_scores:
score_files += [scores]
d_pseudo = copy(d)
d.pop("fixed_random_subset", None)
d_pseudo["audio"] = None
Expand All @@ -526,7 +542,25 @@ def get_dataset(self, key: str, *, training: bool = False, subset: Optional[int]
"data": ("zip_dataset", "data"),
"classes": ("pseudo_labels_dataset", "classes"),
}
if return_scores:
d_weights = {
"class": "HDFDataset",
"files": score_files,
"use_cache_manager": True,
}
d_comb["weights_datasets"] = d_weights
data_map["weights"] = ("weights_datasets", "data")
d = MetaDataset(data_map, d_comb, "pseudo_labels_dataset").as_returnn_opts()
elif return_scores:
score_files = []
for part in parts:
score_files += [GetScoresDummy(_get_bliss_corpus_dict(None, None)[part], self.pseudo_nbest).scores_file]
d_weights = {
"class": "HDFDataset",
"files": score_files,
"use_cache_manager": True,
}
d = MetaDataset({"data": ("zip_dataset", "data"), "classes": ("zip_dataset", "classes"), "weights": ("weights_datasets", "data")}, {"zip_dataset": d, "weights_datasets": d_weights}, "zip_dataset").as_returnn_opts()
return d

class LibrispeechLmDataset(DatasetConfig):
Expand Down
76 changes: 73 additions & 3 deletions users/mueller/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import os
import gzip
import numpy as np

from typing import Dict, Tuple, Union, Any, Optional, Sequence

from i6_core.lib import corpus
from sisyphus import tools as sis_tools
from sisyphus import Job, Task as SisTask, tk
from i6_core.util import uopen
from i6_core.lib.hdf import get_returnn_simple_hdf_writer

class CorpusReplaceOrthFromPyDictJob(Job):
"""
Merge HDF pseudo labels back into a bliss corpus
"""

def __init__(self, bliss_corpus, recog_words_file, segment_file=None):
def __init__(self, bliss_corpus, recog_words_file, segment_file=None, return_scores=False):
"""
:param Path bliss_corpus: Bliss corpus
:param Path recog_words_file: a recog_words file
Expand All @@ -22,6 +26,10 @@ def __init__(self, bliss_corpus, recog_words_file, segment_file=None):
self.segment_file = segment_file

self.out_corpus = self.output_path("corpus.xml.gz")
self.scores_file = None
if return_scores:
self.scores_file = self.output_path("scores.hdf")
self.return_scores = return_scores

def tasks(self):
    """Sisyphus entry point: declares a single "run" task with its resource requirements."""
    yield SisTask("run", rqmt={"cpu": 4, "mem": 8, "time": 4})
Expand All @@ -46,6 +54,9 @@ def run(self):
assert isinstance(d, dict), "only search output file with dict format is supported"

j = 0
if self.return_scores:
SimpleHDFWriter = get_returnn_simple_hdf_writer(None)
out_hdf = SimpleHDFWriter(filename=self.scores_file.get_path(), dim=None)
for segment in segment_iterator:
assert segment.fullname() in d, f"Segment {segment.fullname()} not in search output"
line = d[segment.fullname()]
Expand All @@ -57,28 +68,87 @@ def run(self):
j += 1
else:
if isinstance(line, list):
assert self.return_scores
lines = []
scores = []
for e in line:
new_str = e[1].strip()
if new_str:
if new_str in lines:
raise ValueError(f"Duplicate pseudo label {new_str} in segment {segment.fullname()}")
else:
lines.append(new_str)
scores.append(e[0])
else:
print(f"Empty pseudo label in segment {segment.fullname()}")
lines.append("")
scores.append(e[0])
line = " ZZZZZ ".join(lines)
if len(lines) != 2:
print(f"Segment {segment.fullname()} does not have enough pseudo labels. ({line})")
out_hdf.insert_batch(
inputs=np.array(scores, dtype=np.float32).reshape(1, -1),
seq_len=[len(scores)],
seq_tag=[segment.fullname()],
)
segment.orth = line.strip()
n = len(c.recordings)
m = len(d)
assert m == n + j, f"Number of segments in corpus ({n+j}) does not match number of segments in search output ({m})"

if self.return_scores:
out_hdf.close()

print(f"Number of segments with empty pseudo label: {j} out of {m}, Percentage: {j/m}")
c.dump(self.out_corpus.get_path())

@classmethod
def hash(cls, parsed_args: Dict[str, Any]) -> str:
    """
    Compute the Sisyphus job hash from the parsed constructor arguments.

    :param parsed_args: constructor arguments of the job as parsed by Sisyphus
    :return: hash for job given the arguments
    """
    # Extend the default hash() function.
    d = parsed_args.copy()
    # Drop "return_scores" when it is falsy — presumably so jobs created before
    # this option existed keep their previous hash (i.e. cached jobs are not
    # invalidated). NOTE(review): confirm this matches the intended cache behavior.
    if not d["return_scores"]:
        d.pop("return_scores")

    return sis_tools.sis_hash(d)

class GetScoresDummy(Job):
    """
    Produce a placeholder score HDF for a corpus that has no pseudo labels.

    Each segment of the given Bliss corpus gets one n-best score row of
    ``[0.0, -1e30, ..., -1e30]``, i.e. all weight on the first hypothesis.
    """

    def __init__(self, bliss_corpus: tk.Path, pseudo_nbest: int):
        """
        :param bliss_corpus: Bliss corpus whose segments define the HDF sequence tags
        :param pseudo_nbest: number of score entries written per segment
        """
        self.bliss_corpus = bliss_corpus
        self.pseudo_nbest = pseudo_nbest
        self.scores_file = self.output_path("dummy_scores.hdf")

    def tasks(self):
        """Sisyphus entry point: a single "run" task."""
        yield SisTask("run", rqmt={"cpu": 4, "mem": 8, "time": 4})

    def run(self):
        bliss = corpus.Corpus()
        bliss.load(self.bliss_corpus.get_path())

        # Sequence tags in corpus order; one score row per segment.
        tags = [seg.fullname() for seg in bliss.segments()]
        num_seqs = len(tags)

        # Identical row for every segment: the first entry wins, all others get
        # an effectively minus-infinite (log-space) score.
        row = [0.0] + [-1e30] * (self.pseudo_nbest - 1)
        scores = np.array([row] * num_seqs, dtype=np.float32)
        assert scores.shape == (num_seqs, self.pseudo_nbest)

        SimpleHDFWriter = get_returnn_simple_hdf_writer(None)
        writer = SimpleHDFWriter(filename=self.scores_file.get_path(), dim=None)
        writer.insert_batch(
            inputs=scores,
            seq_len=[self.pseudo_nbest] * num_seqs,
            seq_tag=tags,
        )
        writer.close()

def get_ogg_zip_dict_pseudo_labels(bliss_corpus_dict: Dict[str, tk.Path]) -> Dict[str, tk.Path]:
from i6_core.returnn.oggzip import BlissToOggZipJob
import os
Expand Down
Loading

0 comments on commit 05f9013

Please sign in to comment.