Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jobs to study alignments #558

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
9a7506c
Fix DumpAlignmentJob flow
Icemole Nov 19, 2024
e07caff
Add job to dump text/alignment pairs for all segments
Icemole Nov 19, 2024
57db395
Revert changes in branch to main's
Icemole Nov 19, 2024
258b31d
Black
Icemole Nov 19, 2024
528bb29
Remove job from wrong location
Icemole Nov 19, 2024
dd5bec0
Add job to correct location, add PlotViterbiAlignmentJob
Icemole Nov 19, 2024
8217e5d
Fixes
Icemole Nov 19, 2024
597dfa4
More fixes
Icemole Nov 19, 2024
464ebc8
Fix when alignment is empty
Icemole Nov 19, 2024
cbdab57
Black
Icemole Nov 19, 2024
b61e063
Add file for faulty/empty alignment seqtags
Icemole Nov 19, 2024
b26a326
Work
Icemole Nov 20, 2024
15df7fe
Remove original author from docstring
Icemole Nov 20, 2024
4a703f7
PlotViterbiAlignmentJob: add functionality to plot subset of seq tags
Icemole Nov 20, 2024
cc07c0c
DumpSegmentTextAlignmentJob: always compress output csv
Icemole Nov 20, 2024
dbc4b09
DumpSegmentTextAlignmentJob: add functionality to plot subset of seq …
Icemole Nov 20, 2024
3c85bdd
More work
Icemole Nov 20, 2024
6ed4814
Fix uopen call
Icemole Nov 20, 2024
d155c35
Don't interpolate plot
Icemole Nov 20, 2024
7bb0009
alignment_files -> alignment_caches
Icemole Nov 20, 2024
899a68a
Add full orth function
Icemole Dec 10, 2024
6db360c
Black
Icemole Dec 10, 2024
3143196
Revert _orth change
Icemole Dec 10, 2024
7a550d4
Shorten full_orth code
Icemole Dec 10, 2024
f34fcfd
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Dec 18, 2024
31c6488
Fix race condition in DumpSegmentTextAlignmentJob
Icemole Dec 18, 2024
ba20708
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Dec 18, 2024
0b16c52
Merge remote-tracking branch 'origin/use-orth-function' into segment-…
Icemole Dec 18, 2024
923a821
Use full orth instead of only orth
Icemole Dec 18, 2024
5754ad7
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Dec 18, 2024
0104ff5
Parallelize or not
Icemole Feb 5, 2025
5c1c1cd
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Feb 5, 2025
7fd2cd8
Fix imports
Icemole Feb 5, 2025
18ccc5a
Black
Icemole Feb 5, 2025
d867945
Add `pretty_format_orth` parameter to `Corpus.load()`
Icemole Feb 21, 2025
ed5e643
Fix default param value
Icemole Feb 21, 2025
8690cc6
Fix variable not created before
Icemole Feb 21, 2025
bab9d8a
Address reviewers' comments
Icemole Feb 24, 2025
ea826d1
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Feb 24, 2025
ca7dd7a
Prevent spaces from accumulating to the left/right of the orth
Icemole Feb 24, 2025
b5b8fa4
Merge remote-tracking branch 'origin/pretty-format-orth' into segment…
Icemole Feb 24, 2025
112d95b
Merge remote-tracking branch 'origin/main' into segment-text-alignmen…
Icemole Mar 18, 2025
8e28f11
Fix black
Icemole Mar 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 303 additions & 5 deletions mm/alignment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Public names of this module, i.e. the names exported by a star-import of it.
__all__ = [
"get_seq_tag_to_alignment_mapping",
"AlignmentJob",
"DumpAlignmentJob",
"PlotAlignmentJob",
"AMScoresFromAlignmentLogJob",
"ComputeTimeStampErrorJob",
"GetLongestAllophoneFileJob",
"DumpSegmentTextAlignmentJob",
"PlotViterbiAlignmentJob",
]

import itertools
Expand All @@ -13,20 +16,41 @@
import os
import shutil
import statistics
from typing import Callable, Counter, Dict, Iterable, List, Optional, Tuple, Union
import xml.etree.ElementTree as ET
from typing import Callable, Counter, List, Optional, Tuple, Union

from sisyphus import *

Path = setup_path(__package__)
import numpy as np
from sisyphus import Job, Task, setup_path, tk

import i6_core.lib.corpus as corpus
import i6_core.lib.rasr_cache as rasr_cache
import i6_core.rasr as rasr
import i6_core.util as util

from .flow import alignment_flow, dump_alignment_flow


Path = setup_path(__package__)


# Mapping from a sequence tag to its alignment, given as a list of
# (timestamp, allophone_id, hmm_state, alignment_weight) tuples.
_SeqTagToAlignmentType = Dict[str, List[Tuple[int, int, int, float]]]


def get_seq_tag_to_alignment_mapping(
    alignment_cache: rasr_cache.FileArchive,
) -> _SeqTagToAlignmentType:
    """
    Read all alignments stored in an already opened alignment cache.

    :param alignment_cache: Opened alignment cache from which to extract the alignments.
    :return: Mapping from sequence tags to alignments (by frame).
        The alignments are a list of tuples (timestamp, allophone_id, hmm_state, alignment_weight).
    """
    mapping = {}
    for entry_name in alignment_cache.ft.keys():
        # Entries ending in ".attribs" are not alignments themselves, so they are left out.
        if entry_name.endswith(".attribs"):
            continue
        mapping[entry_name] = alignment_cache.read(entry_name, "align")
    return mapping


class AlignmentJob(rasr.RasrCommand, Job):
"""
Align a dataset with the given feature scorer.
Expand Down Expand Up @@ -149,7 +173,6 @@ def run(self, task_id):
)

def plot(self):
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -819,3 +842,278 @@ def run(self):
line_set = {*lines} - {None}
assert len(line_set) == 1, f"Line {i}: expected only one allophone, but found two or more: {line_set}."
f.write(list(line_set)[0])


class DumpSegmentTextAlignmentJob(Job):
    """
    Dumps all text and alignments for the given corpus and alignment files
    in a human-readable format defined as follows:
    ```
    <seq-tag>
    <text>
    <alignment-index-0> <start-0> <end-0> <triphone-0> <weight-0>
    <alignment-index-1> <start-1> <end-1> <triphone-1> <weight-1>
    ...
    ```
    Each `<triphone-i>` is written as `<allophone>.<hmm-state>`, and every segment is followed by an empty line.

    Only segments that are present both in the alignments and in the corpus are written.
    Within one run task, the segments are written sorted by sequence tag, so the output is reproducible.
    """

    def __init__(
        self,
        corpus_file: tk.Path,
        alignment_caches: Iterable[tk.Path],
        allophone_file: tk.Path,
        seq_tags_to_dump: Optional[tk.Path] = None,
        frame_size: float = 0.25,
        frame_step: float = 0.1,
    ):
        """
        :param corpus_file: Corpus file to get the text from.
        :param alignment_caches: Alignment files to get the alignments from.
            Must correspond to the corpus given in :param:`corpus_file` for the job to work properly.
        :param allophone_file: Allophone file with which the alignments given in :param:`alignment_caches` were dumped.
        :param seq_tags_to_dump: File with the specific sequence tags to dump, one per line (empty lines are ignored).
            By default, dump all sequences given in :param:`alignment_caches`.
        :param frame_size: Frame size. Only used to calculate the timestamps of the alignments
            (`end = frame_step * index + frame_size`).
        :param frame_step: Frame step. Only used to calculate the timestamps of the alignments
            (`start = frame_step * index`).

        NOTE(review): the defaults (0.25 / 0.1) are ten times a 25ms window / 10ms shift, assuming seconds - confirm
        that they are intended. They are left untouched here.
        """
        self.corpus_file = corpus_file
        # Materialize the iterable: the job needs `len()` and positional access to distribute the caches over tasks.
        self.alignment_caches = list(alignment_caches)
        self.allophone_file = allophone_file
        self.seq_tags_to_dump = seq_tags_to_dump
        self.frame_size = frame_size
        self.frame_step = frame_step

        self.out_text_alignment_pairs = self.output_path("segment_txt_alignment.txt.gz")

        self.rqmt = {"cpu": 1, "mem": 2.0, "time": 1.0}

    def tasks(self):
        if self.seq_tags_to_dump is not None:
            # Do not parallelize: a single task writes the final output directly.
            yield Task("run", resume="run", rqmt=self.rqmt)
        else:
            # Parallelize: one task per alignment cache (1-based task ids), followed by a merge of the partial dumps.
            yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, len(self.alignment_caches) + 1))
            yield Task("merge", resume="merge")

    def _load_requested_alignments(self):
        """
        Load the alignments of the sequence tags listed in `self.seq_tags_to_dump` from all alignment caches.

        :return: Tuple (seq_tag -> alignment, allophones). The allophones are the `allophones` attribute
            of the last opened cache, or `None` if there are no caches.
        """
        # Load the seq tags to dump. Blank lines would otherwise become an empty tag that can never be found.
        requested_seq_tags = set()
        with util.uopen(self.seq_tags_to_dump.get_path(), "rt") as f:
            for line in f:
                seq_tag = line.strip()
                if seq_tag:
                    requested_seq_tags.add(seq_tag)

        # Pick the requested seq tags from the alignment caches.
        seq_tag_to_alignments = {}
        allophones = None
        for alignment_cache in self.alignment_caches:
            align_cache = rasr_cache.FileArchive(alignment_cache.get_path())
            align_cache.setAllophones(self.allophone_file.get_path())
            # All caches are set up with the same allophone file, so any of them can serve for the lookup.
            allophones = align_cache.allophones
            for seq_tag, alignments in get_seq_tag_to_alignment_mapping(align_cache).items():
                if seq_tag in requested_seq_tags:
                    seq_tag_to_alignments[seq_tag] = alignments

        # Check that all sequences provided by the user are in the alignments.
        for seq_tag in sorted(requested_seq_tags):
            assert (
                seq_tag in seq_tag_to_alignments
            ), f"The sequence tag {seq_tag} provided in seq_tags_to_dump is not in the provided alignment files."

        return seq_tag_to_alignments, allophones

    def run(self, task_id: Optional[int] = None):
        """
        Dump the text/alignment pairs.

        :param task_id: 1-based index of the alignment cache to process when running in parallel.
            Not used when `seq_tags_to_dump` was given.
        """
        # Get the alignment information: seq_tag -> alignment.
        if self.seq_tags_to_dump is not None:
            seq_tag_to_alignments, allophones = self._load_requested_alignments()
            output_file = self.out_text_alignment_pairs.get_path()
        else:
            # Load specific task_id alignment cache and dump everything from it.
            align_cache = rasr_cache.FileArchive(self.alignment_caches[task_id - 1].get_path())
            align_cache.setAllophones(self.allophone_file.get_path())
            seq_tag_to_alignments = get_seq_tag_to_alignment_mapping(align_cache)
            allophones = align_cache.allophones
            # Partial result, picked up again by the merge task.
            output_file = f"intermediate_segment_txt_alignment.{task_id}.txt.gz"

        # Get the corpus information: seq_tag -> text.
        c = corpus.Corpus()
        c.load(self.corpus_file.get_path())
        seq_tag_to_text = {seq_tag: segment.full_orth() for seq_tag, segment in c.get_segment_mapping().items()}

        frame_step = self.frame_step
        frame_size = self.frame_size
        with util.uopen(output_file, "wt") as f:
            # Sorting makes the dump independent of the (randomized) set iteration order.
            for seq_tag in sorted(seq_tag_to_alignments.keys() & seq_tag_to_text.keys()):
                f.write(f"{seq_tag}\n{seq_tag_to_text[seq_tag]}\n")
                f.writelines(
                    f"{align_idx} "
                    f"{(frame_step * align_idx):.3f} "
                    f"{(frame_step * align_idx + frame_size):.3f} "
                    f"{allophones[allo_id]}.{hmm_state} "
                    f"{weight:.3f}\n"
                    for align_idx, allo_id, hmm_state, weight in seq_tag_to_alignments[seq_tag]
                )
                f.write("\n")

    def merge(self):
        """
        Concatenate the partial dumps written by the parallel run tasks into the final output, in task order.
        """
        with util.uopen(self.out_text_alignment_pairs.get_path(), "wt") as f_out:
            for i in range(1, len(self.alignment_caches) + 1):
                with util.uopen(f"intermediate_segment_txt_alignment.{i}.txt.gz", "rt") as f_in:
                    for line in f_in:
                        f_out.write(line)


class PlotViterbiAlignmentJob(Job):
    """
    Plots the alignments of each segment in the specified alignment files.

    One PNG file per segment is written below the output directory, at `<seq-tag>.png`.
    """

    def __init__(
        self,
        alignment_caches: Iterable[tk.Path],
        allophone_file: tk.Path,
        seq_tags_to_plot: Optional[tk.Path] = None,
        corpus_file: Optional[tk.Path] = None,
    ):
        """
        :param alignment_caches: Alignment files to be plotted.
        :param allophone_file: Allophone file used in the alignment process.
        :param seq_tags_to_plot: File with the specific sequence tags to plot, one per line (empty lines are ignored).
            By default, plot all sequences given in :param:`alignment_caches`.
        :param corpus_file: Corpus used to generate the alignments. By default, the plots have no title.
            If provided, the plots will have the text from the respective segment as title,
            whenever the segment is available in the corpus. This should only be given for convenience.
        """
        # Materialize the iterable: the job needs `len()` and positional access to distribute the caches over tasks.
        self.alignment_caches = list(alignment_caches)
        self.allophone_file = allophone_file
        self.seq_tags_to_plot = seq_tags_to_plot
        self.corpus_file = corpus_file

        self.out_plot_dir = self.output_path("plots", directory=True)

        self.rqmt = {"cpu": 1, "mem": 2.0, "time": 1.0}

    def tasks(self):
        if self.seq_tags_to_plot is not None:
            # Do not parallelize.
            yield Task("run", resume="run", rqmt=self.rqmt)
        else:
            # Parallelize: one task per alignment cache (1-based task ids).
            yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, len(self.alignment_caches) + 1))

    def extract_phoneme_sequence(self, alignment: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Collapse runs of identical consecutive labels of a (non-empty) alignment.

        :param alignment: Monophone alignment, for instance: `np.array(["a", "a", "b", ...])`.
        :return: Monophone sequence (ordered as given),
            as well as the indices corresponding to the monophone sequence from the Viterbi alignment.
        """
        # Index of the last frame of every run of identical labels.
        boundaries = np.concatenate(
            [
                np.where(alignment[:-1] != alignment[1:])[0],
                [len(alignment) - 1],  # manually add boundary of last allophone
            ]
        )

        # Run lengths: distance between consecutive boundaries (the first run starts right after index -1).
        lengths = boundaries - np.concatenate([[-1], boundaries[:-1]])
        phonemes = alignment[boundaries]
        monotonic_idx_alignment = np.repeat(np.arange(len(phonemes)), lengths)
        return phonemes, monotonic_idx_alignment

    def make_viterbi_matrix(self, label_idx_seq: np.ndarray) -> np.ndarray:
        """
        :param label_idx_seq: For every frame, the index of the label aligned to it.
        :return: Matrix corresponding to the Viterbi alignment, of shape (number of labels, number of frames),
            holding 1.0 where a label is aligned to a frame and 0.0 everywhere else.
        """
        label_idx_seq = np.asarray(label_idx_seq)
        num_frames = len(label_idx_seq)
        num_labels = int(label_idx_seq.max()) + 1
        viterbi_matrix = np.zeros((num_labels, num_frames), dtype=np.float32)
        # Set one entry per frame (column) in a single vectorized assignment.
        viterbi_matrix[label_idx_seq, np.arange(num_frames)] = 1.0
        return viterbi_matrix

    def plot(self, viterbi_matrix: np.ndarray, allophone_sequence: List[str], file_name: str, title: str = ""):
        """
        Plot the given matrix and store it as a PNG file.

        :param viterbi_matrix: Matrix to be plotted, corresponding to the Viterbi alignment.
        :param allophone_sequence: Allophone sequence (Y-axis tick labels).
        :param file_name: File name where to store the plot, relative to `<job>/output/plots/`.
            The suffix `.png` is appended.
        :param title: Optional title to add to the image. By default there will be no title.
        """
        import matplotlib

        # Select the non-interactive backend before pyplot gets imported.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        num_labels, num_frames = np.shape(viterbi_matrix)

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xlabel("Frame")
        ax.xaxis.set_label_coords(0.98, -0.03)
        ax.set_xbound(0, num_frames - 1)
        ax.set_ybound(-0.5, num_labels - 0.5)

        ax.set_yticks(np.arange(num_labels))
        ax.set_yticklabels(allophone_sequence)

        ax.set_title(title)

        ax.imshow(viterbi_matrix, cmap="Blues", interpolation="none", aspect="auto", origin="lower")

        # The plot will be purposefully divided into subdirectories
        # (whenever the file name contains path separators).
        target_file = os.path.join(self.out_plot_dir.get_path(), f"{file_name}.png")
        os.makedirs(os.path.dirname(target_file), exist_ok=True)
        fig.savefig(target_file)
        # Release the figure; otherwise the figures of all segments would stay alive.
        plt.close(fig)

    def _load_requested_alignments(self):
        """
        Load the alignments of the sequence tags listed in `self.seq_tags_to_plot` from all alignment caches.

        :return: Tuple (seq_tag -> alignment, allophones). The allophones are the `allophones` attribute
            of the last opened cache, or `None` if there are no caches.
        """
        # Load the seq tags to plot. Blank lines would otherwise become an empty tag that can never be found.
        requested_seq_tags = set()
        with util.uopen(self.seq_tags_to_plot.get_path(), "rt") as f:
            for line in f:
                seq_tag = line.strip()
                if seq_tag:
                    requested_seq_tags.add(seq_tag)

        # Pick the requested seq tags from the alignment caches.
        seq_tag_to_alignments = {}
        allophones = None
        for alignment_cache in self.alignment_caches:
            align_cache = rasr_cache.FileArchive(alignment_cache.get_path())
            align_cache.setAllophones(self.allophone_file.get_path())
            # All caches are set up with the same allophone file, so any of them can serve for the lookup.
            allophones = align_cache.allophones
            for seq_tag, alignments in get_seq_tag_to_alignment_mapping(align_cache).items():
                if seq_tag in requested_seq_tags:
                    seq_tag_to_alignments[seq_tag] = alignments

        # Check that all sequences provided by the user are in the alignments.
        for seq_tag in sorted(requested_seq_tags):
            assert (
                seq_tag in seq_tag_to_alignments
            ), f"The sequence tag {seq_tag} provided in seq_tags_to_plot is not in the provided alignment files."

        return seq_tag_to_alignments, allophones

    def run(self, task_id: Optional[int] = None):
        """
        Plot the alignments.

        :param task_id: 1-based index of the alignment cache to process when running in parallel.
            Not used when `seq_tags_to_plot` was given.
        """
        if self.seq_tags_to_plot is not None:
            seq_tag_to_alignments, allophones = self._load_requested_alignments()
        else:
            # Load specific task_id alignment cache and plot everything from it.
            align_cache = rasr_cache.FileArchive(self.alignment_caches[task_id - 1].get_path())
            align_cache.setAllophones(self.allophone_file.get_path())
            seq_tag_to_alignments = get_seq_tag_to_alignment_mapping(align_cache)
            allophones = align_cache.allophones

        # Optional titles: seq_tag -> text.
        seq_tag_to_text = {}
        if self.corpus_file is not None:
            c = corpus.Corpus()
            c.load(self.corpus_file.get_path())
            seq_tag_to_text = {seq_tag: segment.full_orth() for seq_tag, segment in c.get_segment_mapping().items()}

        empty_alignment_seq_tags = []
        for seq_tag, alignments in seq_tag_to_alignments.items():
            # In some rare cases, the alignment doesn't have to reach a satisfactory end.
            # In these cases, the final alignment is empty. Skip those cases.
            if len(alignments) == 0:
                empty_alignment_seq_tags.append(seq_tag)
                continue

            # Get the central part of every allophone (everything before the first "{").
            # Built as a new list, so that the loaded alignments aren't overwritten.
            center_allophones = np.array([allophones[allo_id].split("{")[0] for _, allo_id, _, _ in alignments])
            phonemes, alignment_indices = self.extract_phoneme_sequence(center_allophones)
            viterbi_matrix = self.make_viterbi_matrix(alignment_indices)
            self.plot(viterbi_matrix, phonemes, file_name=seq_tag, title=seq_tag_to_text.get(seq_tag, ""))

        if empty_alignment_seq_tags:
            logging.warning(
                "The following alignments weren't plotted because their alignments were empty:\n"
                f"{empty_alignment_seq_tags}"
            )