Skip to content

Commit

Permalink
Update Raissi setups: add `apply_spec_augment` and `parallelize_lat2ctm` options, remove the deprecated `recognize_optimize_scales` (superseded by `recognize_optimize_scales_v2`) and the `cpu_slow` flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Marvin84 committed Mar 6, 2025
1 parent 8c16881 commit c79d713
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 198 deletions.
25 changes: 18 additions & 7 deletions users/raissi/setups/common/TF_factored_hybrid_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,26 @@ def get_conformer_network_zhou_variant(
spec_augment_as_data: bool = True,
auxilary_loss_layers: list = [6],
frame_rate_reduction_ratio_info: Optional[net_helpers.FrameRateReductionRatioinfo] = None,
apply_spec_augment: bool = True,
):

if frame_rate_reduction_ratio_info is None:
frame_rate_reduction_ratio_info = self.frame_rate_reduction_ratio_info
encoder_net = {
"specaug": {
"class": "eval",
"from": "data",
"eval": f"self.network.get_config().typed_value('transform')(source(0, as_data={spec_augment_as_data}), network=self.network)",

if apply_spec_augment:
init_layer = "specaug"
encoder_net = {
init_layer : {
"class": "eval",
"from": "data",
"eval": f"self.network.get_config().typed_value('transform')(source(0, as_data={spec_augment_as_data}), network=self.network)",
}
}
}
from_list = encoder_archs.add_initial_conv(network=encoder_net, linear_size=conf_model_dim, from_list="specaug")
else:
init_layer = "data"
encoder_net = {}
from_list = encoder_archs.add_initial_conv(network=encoder_net, linear_size=conf_model_dim, from_list=init_layer)

encoder_archs.add_conformer_stack(encoder_net, from_list=from_list)
encoder_net[out_layer_name] = {
"class": "copy",
Expand Down Expand Up @@ -1517,6 +1525,7 @@ def get_best_recog_scales_and_transition_values(
use_heuristic_tdp: bool = False,
extend: bool = True,
use_speech_tdp_for_nonword: bool = True,
parallelize_lat2ctm: bool = True,
) -> SearchParameters:

assert self.experiments[key]["decode_job"]["runner"] is not None, "Please set the recognizer"
Expand Down Expand Up @@ -1579,6 +1588,7 @@ def get_best_recog_scales_and_transition_values(
prior_scales=prior_scales,
tdp_scales=tdp_scales,
pron_scales=pron_scales,
parallelize_lat2ctm=parallelize_lat2ctm,
)

if use_heuristic_tdp:
Expand Down Expand Up @@ -1619,6 +1629,7 @@ def get_best_recog_scales_and_transition_values(
altas_beam=16.0,
tdp_sil=nnsp_tdp,
tdp_speech=sp_tdp,
parallelize_lat2ctm=parallelize_lat2ctm,
)

return best_config
198 changes: 10 additions & 188 deletions users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def recognize_count_lm(
rtf_gpu: float = 4,
lm_config: rasr.RasrConfig = None,
create_lattice: bool = True,
parallelize_lat2ctm: bool = True,
separate_lm_image_gc_generation: bool = False,
search_rqmt_update=None,
adv_search_extra_config: Optional[rasr.RasrConfig] = None,
Expand Down Expand Up @@ -655,6 +656,7 @@ def recognize_count_lm(
adv_search_extra_config=adv_search_extra_config,
adv_search_extra_post_config=adv_search_extra_post_config,
cpu_omp_thread=cpu_omp_thread,
parallelize_lat2ctm=parallelize_lat2ctm,
separate_lm_image_gc_generation=separate_lm_image_gc_generation,
)

Expand Down Expand Up @@ -691,6 +693,7 @@ def recognize(
lm_lookahead_options: Optional = {},
search_rqmt_update=None,
cpu_omp_thread=2,
parallelize_lat2ctm: bool = True,
separate_lm_image_gc_generation: bool = False,
) -> DecodingJobs:
if isinstance(search_parameters, SearchParameters):
Expand Down Expand Up @@ -985,7 +988,7 @@ def recognize(
lat2ctm = recog.LatticeToCtmJob(
crp=search_crp,
lattice_cache=search.out_lattice_bundle,
parallelize=True,
parallelize=parallelize_lat2ctm,
best_path_algo=self.shortest_path_algo.value,
extra_config=lat2ctm_extra_config,
fill_empty_segments=True,
Expand Down Expand Up @@ -1075,6 +1078,7 @@ def recognize(
rerun_after_opt_lm=rerun_after_opt_lm,
search_parameters=params,
use_estimated_tdps=use_estimated_tdps,
parallelize_lat2ctm=parallelize_lat2ctm,
)

return DecodingJobs(
Expand All @@ -1086,183 +1090,6 @@ def recognize(
search_stats=stat,
)

def recognize_optimize_scales(
self,
*,
label_info: LabelInfo,
num_encoder_output: int,
search_parameters: SearchParameters,
prior_scales: Union[
List[Tuple[float]], # center
List[Tuple[float, float]], # center, left
List[Tuple[float, float, float]], # center, left, right
np.ndarray,
],
tdp_scales: Union[List[float], np.ndarray],
tdp_sil: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
tdp_speech: Optional[List[Tuple[TDP, TDP, TDP, TDP]]] = None,
pron_scales: Union[List[float], np.ndarray] = None,
altas_value=14.0,
altas_beam=14.0,
keep_value=10,
cpu_rqmt: Optional[int] = None,
mem_rqmt: Optional[int] = None,
crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
pre_path: str = "scales",
cpu_slow: bool = True,
) -> SearchParameters:
"""Grid-search recognition scales and return the best configuration.

Launches one `recognize_count_lm` decoding job per point in the Cartesian
product of `prior_scales` x `tdp_scales` x `tdp_sil` x `tdp_speech`
(x `pron_scales`, when pronunciation scaling is active), registers the
resulting WERs as Sisyphus outputs, and picks the argmin by number of
errors. Returns `search_parameters` with the winning scales substituted.

:param label_info: label topology/metadata forwarded to each decoding job.
:param num_encoder_output: encoder output dimension forwarded to decoding.
:param search_parameters: base parameters; the grid values are substituted
    into copies of this via `dataclasses.replace`.
:param prior_scales: prior-scale candidates; tuples are padded with 0.0 up
    to (center, left, right).
:param tdp_scales: TDP-scale candidates.
:param tdp_sil: silence-TDP candidates; defaults to the single value from
    `search_parameters`.
:param tdp_speech: speech-TDP candidates; defaults likewise.
:param pron_scales: pronunciation-scale candidates; only used when the
    lexicon normalizes pronunciation (see `use_pron` below).
:param altas_value: ALTAS pruning value used during the (faster) grid search.
:param altas_beam: beam used during the grid search.
:param keep_value: Sisyphus keep value for search/lat2ctm jobs of the grid.
:param cpu_rqmt: optional CPU requirement override for each decoding job.
:param mem_rqmt: optional memory requirement override for each decoding job.
:param crp_update: optional callback to mutate the CRP config per job.
:param pre_path: output-alias prefix for the per-gridpoint jobs.
:param cpu_slow: if True, mark search jobs with the "cpu_slow" rqmt flag.
:return: `SearchParameters` with the best prior/TDP (and pron) scales;
    prior scales are filled in according to the context type
    (monophone / joint diphone / diphone / triphone).
"""
assert len(prior_scales) > 0
assert len(tdp_scales) > 0

# Grid search runs with tighter pruning (ALTAS + reduced beam) for speed.
recog_args = dataclasses.replace(search_parameters, altas=altas_value, beam=altas_beam)

if isinstance(prior_scales, np.ndarray):
prior_scales = [(s,) for s in prior_scales] if prior_scales.ndim == 1 else [tuple(s) for s in prior_scales]

# Normalize every entry to a (center, left, right) triple, padding with 0.0.
prior_scales = [tuple(round(p, 2) for p in priors) for priors in prior_scales]
prior_scales = [
(p, 0.0, 0.0)
if isinstance(p, float)
else (p[0], 0.0, 0.0)
if len(p) == 1
else (p[0], p[1], 0.0)
if len(p) == 2
else p
for p in prior_scales
]
tdp_scales = [round(s, 2) for s in tdp_scales]
tdp_sil = tdp_sil if tdp_sil is not None else [recog_args.tdp_silence]
tdp_speech = tdp_speech if tdp_speech is not None else [recog_args.tdp_speech]
# Pron scaling only makes sense when the lexicon normalizes pronunciation.
use_pron = self.crp.lexicon_config.normalize_pronunciation and pron_scales is not None

if use_pron:
# One decoding job per grid point, keyed by its scale combination.
# NOTE(review): `diphone=c` ties the joint-diphone prior to the center
# prior — presumably intentional for joint-diphone models; confirm.
jobs = {
((c, l, r), tdp, tdp_sl, tdp_sp, pron): self.recognize_count_lm(
add_sis_alias_and_output=False,
calculate_stats=False,
cpu_rqmt=cpu_rqmt,
crp_update=crp_update,
is_min_duration=False,
keep_value=keep_value,
label_info=label_info,
mem_rqmt=mem_rqmt,
name_override=f"{self.name}-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{tdp_sl}-tdpSp{tdp_sp}-pron{pron}",
num_encoder_output=num_encoder_output,
opt_lm_am=False,
rerun_after_opt_lm=False,
search_parameters=dataclasses.replace(
recog_args, tdp_scale=tdp, tdp_silence=tdp_sl, tdp_speech=tdp_sp, pron_scale=pron
).with_prior_scale(left=l, center=c, right=r, diphone=c),
)
for ((c, l, r), tdp, tdp_sl, tdp_sp, pron) in itertools.product(
prior_scales, tdp_scales, tdp_sil, tdp_speech, pron_scales
)
}

else:
# Same grid without the pronunciation-scale axis.
jobs = {
((c, l, r), tdp, tdp_sl, tdp_sp): self.recognize_count_lm(
add_sis_alias_and_output=False,
calculate_stats=False,
cpu_rqmt=cpu_rqmt,
crp_update=crp_update,
is_min_duration=False,
keep_value=keep_value,
label_info=label_info,
mem_rqmt=mem_rqmt,
name_override=f"{self.name}-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{tdp_sl}-tdpSp{tdp_sp}",
num_encoder_output=num_encoder_output,
opt_lm_am=False,
rerun_after_opt_lm=False,
search_parameters=dataclasses.replace(
recog_args, tdp_scale=tdp, tdp_silence=tdp_sl, tdp_speech=tdp_sp
).with_prior_scale(left=l, center=c, right=r, diphone=c),
)
for ((c, l, r), tdp, tdp_sl, tdp_sp) in itertools.product(prior_scales, tdp_scales, tdp_sil, tdp_speech)
}
# Error counts drive the argmin below (WER is registered for inspection only).
jobs_num_e = {k: v.scorer.out_num_errors for k, v in jobs.items()}

if use_pron:
for ((c, l, r), tdp, tdp_sl, tdp_sp, pron), recog_jobs in jobs.items():
if cpu_slow:
recog_jobs.search.update_rqmt("run", {"cpu_slow": True})

pre_name = f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{pron}-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{format_tdp(tdp_sl)}-tdpSp{format_tdp(tdp_sp)}"

recog_jobs.lat2ctm.set_keep_value(keep_value)
recog_jobs.search.set_keep_value(keep_value)

recog_jobs.search.add_alias(pre_name)
tk.register_output(f"{pre_name}.wer", recog_jobs.scorer.out_report_dir)
else:
for ((c, l, r), tdp, tdp_sl, tdp_sp), recog_jobs in jobs.items():
if cpu_slow:
recog_jobs.search.update_rqmt("run", {"cpu_slow": True})

pre_name = f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{recog_args.pron_scale}-pC{c}-pL{l}-pR{r}-tdp{tdp}-tdpSil{format_tdp(tdp_sl)}-tdpSp{format_tdp(tdp_sp)}"

recog_jobs.lat2ctm.set_keep_value(keep_value)
recog_jobs.search.set_keep_value(keep_value)

recog_jobs.search.add_alias(pre_name)
tk.register_output(f"{pre_name}.wer", recog_jobs.scorer.out_report_dir)

# Best WER is registered for the dashboard; the selected argmin itself is
# computed from the raw error counts (jobs_num_e).
best_overall_wer = ComputeArgminJob({k: v.scorer.out_wer for k, v in jobs.items()})
best_overall_n = ComputeArgminJob(jobs_num_e)
tk.register_output(
f"decoding/scales-best/{self.name}/args",
best_overall_n.out_argmin,
)
tk.register_output(
f"decoding/scales-best/{self.name}/wer",
best_overall_wer.out_min,
)

def push_delayed_tuple(
argmin: DelayedBase,
) -> Tuple[DelayedBase, DelayedBase, DelayedBase, DelayedBase]:
# Re-wrap a delayed 4-tuple (the TDP quadruple) element by element so
# dataclasses.replace receives a real tuple of delayed values.
return tuple(argmin[i] for i in range(4))

# cannot destructure, need to use indices
# (argmin is a Sisyphus delayed value, only lazy indexing is available)
best_priors = best_overall_n.out_argmin[0]
best_tdp_scale = best_overall_n.out_argmin[1]
best_tdp_sil = best_overall_n.out_argmin[2]
best_tdp_sp = best_overall_n.out_argmin[3]
if use_pron:
best_pron = best_overall_n.out_argmin[4]

base_cfg = dataclasses.replace(
search_parameters,
tdp_scale=best_tdp_scale,
tdp_silence=push_delayed_tuple(best_tdp_sil),
tdp_speech=push_delayed_tuple(best_tdp_sp),
pron_scale=best_pron,
)
else:
base_cfg = dataclasses.replace(
search_parameters,
tdp_scale=best_tdp_scale,
tdp_silence=push_delayed_tuple(best_tdp_sil),
tdp_speech=push_delayed_tuple(best_tdp_sp),
)

# Fill in only the prior scales that the model's context type actually uses.
best_center_prior = best_priors[0]
if self.context_type.is_monophone():
return base_cfg.with_prior_scale(center=best_center_prior)
if self.context_type.is_joint_diphone():
return base_cfg.with_prior_scale(diphone=best_center_prior)

best_left_prior = best_priors[1]
if self.context_type.is_diphone():
return base_cfg.with_prior_scale(center=best_center_prior, left=best_left_prior)

# Triphone: all three priors are in play.
best_right_prior = best_priors[2]
return base_cfg.with_prior_scale(
center=best_center_prior,
left=best_left_prior,
right=best_right_prior,
)

def recognize_optimize_scales_v2(
self,
*,
Expand All @@ -1286,8 +1113,8 @@ def recognize_optimize_scales_v2(
cpu_rqmt: Optional[int] = None,
mem_rqmt: Optional[int] = None,
crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
parallelize_lat2ctm: bool = True,
pre_path: str = "scales",
cpu_slow: bool = True,
) -> SearchParameters:
assert len(prior_scales) > 0
assert len(tdp_scales) > 0
Expand Down Expand Up @@ -1331,6 +1158,7 @@ def recognize_optimize_scales_v2(
num_encoder_output=num_encoder_output,
opt_lm_am=False,
rerun_after_opt_lm=False,
parallelize_lat2ctm=parallelize_lat2ctm,
search_parameters=dataclasses.replace(
recog_args,
tdp_scale=tdp,
Expand Down Expand Up @@ -1359,6 +1187,7 @@ def recognize_optimize_scales_v2(
num_encoder_output=num_encoder_output,
opt_lm_am=False,
rerun_after_opt_lm=False,
parallelize_lat2ctm=parallelize_lat2ctm,
search_parameters=dataclasses.replace(
recog_args, tdp_scale=tdp, tdp_silence=tdp_sl, tdp_nonword=tdp_nw, tdp_speech=tdp_sp
).with_prior_scale(left=l, center=c, right=r, diphone=c),
Expand All @@ -1371,9 +1200,6 @@ def recognize_optimize_scales_v2(

if use_pron:
for ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp, pron), recog_jobs in jobs.items():
if cpu_slow:
recog_jobs.search.update_rqmt("run", {"cpu_slow": True})

pre_name = (
f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{pron}-pC{c}-pL{l}-pR{r}-tdp{tdp}-"
f"tdpSil{format_tdp(tdp_sl)}-tdpNw{format_tdp(tdp_nw)}-tdpSp{format_tdp(tdp_sp)}"
Expand All @@ -1386,8 +1212,6 @@ def recognize_optimize_scales_v2(
tk.register_output(f"{pre_name}.wer", recog_jobs.scorer.out_report_dir)
else:
for ((c, l, r), tdp, tdp_sl, tdp_nw, tdp_sp), recog_jobs in jobs.items():
if cpu_slow:
recog_jobs.search.update_rqmt("run", {"cpu_slow": True})

pre_name = (
f"{pre_path}/{self.name}/Lm{recog_args.lm_scale}-Pron{recog_args.pron_scale}"
Expand Down Expand Up @@ -1475,7 +1299,7 @@ def recognize_optimize_transtition_values(
mem_rqmt: Optional[int] = None,
crp_update: Optional[Callable[[rasr.RasrConfig], Any]] = None,
pre_path: str = "transition-values",
cpu_slow: bool = True,
parallelize_lat2ctm: bool = True,
) -> SearchParameters:

recog_args = dataclasses.replace(search_parameters, altas=altas_value, beam=altas_beam)
Expand All @@ -1496,16 +1320,14 @@ def recognize_optimize_transtition_values(
num_encoder_output=num_encoder_output,
opt_lm_am=False,
rerun_after_opt_lm=False,
parallelize_lat2ctm=parallelize_lat2ctm,
search_parameters=dataclasses.replace(recog_args, tdp_silence=tdp_sl, tdp_speech=tdp_sp),
)
for (tdp_sl, tdp_sp) in itertools.product(tdp_sil, tdp_speech)
}
jobs_num_e = {k: v.scorer.out_num_errors for k, v in jobs.items()}

for (tdp_sl, tdp_sp), recog_jobs in jobs.items():
if cpu_slow:
recog_jobs.search.update_rqmt("run", {"cpu_slow": True})

pre_name = f"{pre_path}/{self.name}/" f"tdpSil{format_tdp(tdp_sl)}tdpSp{format_tdp(tdp_sp)}"

recog_jobs.lat2ctm.set_keep_value(keep_value)
Expand Down
3 changes: 2 additions & 1 deletion users/raissi/setups/common/helpers/network/ivectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def add_ivectors_to_conformer_encoder(
num_ivec_input: int,
subsampling_factor: int = 4,
n_ivec_transform: int = 512,
input_layer: str = "specaug"
):
"""
The method follows: https://arxiv.org/pdf/2206.12955
Expand Down Expand Up @@ -83,6 +84,6 @@ def add_ivectors_to_conformer_encoder(
}

network["conformer_1_mhsamod_self_attention"]["from"] = "mhsa_ivec_input"
network["specaug"]["from"] = "source_features"
network[input_layer]["from"] = "source_features"

return network
Loading

0 comments on commit c79d713

Please sign in to comment.