Skip to content

Commit

Permalink
Adding STFT-based SpecAugment for CTC (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
Max-Ryujin authored Mar 4, 2025
1 parent 058854e commit d7796df
Show file tree
Hide file tree
Showing 4 changed files with 402 additions and 5 deletions.
11 changes: 8 additions & 3 deletions users/vieting/experiments/switchboard/ctc/feat/baseline_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def get_nn_args_single(
if layer_config.get("class") != "variable" and layer_config.get("from", "data") == "data":
feature_net["subnetwork"][layer]["from"] = source_layer


returnn_config = get_returnn_config(
num_inputs=1,
num_outputs=num_outputs,
Expand Down Expand Up @@ -160,6 +161,7 @@ def get_returnn_config(
conformer_type: str = "wei",
specaug_old: Optional[Dict[str, Any]] = None,
specaug_config: Optional[Dict[str, Any]] = None,
specaug_stft: Optional[Dict[str, Any]] = None,
am_args: Optional[Dict[str, Any]] = None,
batch_size: Union[int, Dict[str, int]] = 10000,
sample_rate: int = 8000,
Expand Down Expand Up @@ -218,13 +220,17 @@ def get_returnn_config(
conformer_type=conformer_type,
specaug_old=specaug_old,
specaug_config=specaug_config,
specaug_stft=specaug_stft,
recognition=recognition,
num_epochs=num_epochs,
)
feature_net = copy.deepcopy(feature_net)

if audio_perturbation:
prolog += get_code_for_perturbation()
for layer in list(network.keys()):
if layer in ("stft"):
continue
if network[layer]["from"] == "data":
network[layer]["from"] = "features"
elif isinstance(network[layer]["from"], list) and "data" in network[layer]["from"]:
Expand All @@ -237,14 +243,13 @@ def get_returnn_config(
network.pop(layer)
network["source"] = {"class": "copy", "from": "features"}
else:
# network["source"] = specaug_layer_jingjing(in_layer=["features"])
pass
if specaug_stft is not None:
feature_net["from"] = "istft"

if audio_perturbation and recognition:
# Remove pre-processing from recognition and replace with layers in the network if needed
datasets["train"]["dataset"]["audio"].pop("pre_process", None)

feature_net = copy.deepcopy(feature_net)
audio_perturb_args = extra_args.get("audio_perturb_args", {})
source_layer = "data"
if "preemphasis" in audio_perturb_args:
Expand Down
200 changes: 200 additions & 0 deletions users/vieting/experiments/switchboard/ctc/feat/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,204 @@ def run_mel_audio_perturbation_from_checkpoint():
return report


def run_specaug_stft_experiments():
    """
    Set up and run the CTC conformer trainings that use STFT-based SpecAugment.

    Eleven configurations are registered. They differ in the feature extraction
    (``ScfNetwork`` vs. ``LogMelNetwork``), the batch size / gradient accumulation
    setting and the ``specaug_stft`` options (masking limits and, for the
    ``stft20ms`` variants, explicit STFT frame size / shift / FFT size). The
    ``stft10ms`` variants only pass ``max_feature`` and leave the STFT parameters
    to whatever the consumer of ``specaug_stft`` uses as defaults.

    :return: the report and the NN system object, exactly as returned by ``run_nn_args``
    """
    gs.ALIAS_AND_OUTPUT_SUBDIR = "experiments/switchboard/ctc/feat/"

    (
        returnn_datasets,
        rasr_loss_corpus_path,
        rasr_loss_corpus_segments,
        rasr_loss_lexicon_path,
        dev_corpora,
    ) = get_datasets()

    # Settings shared by all experiments; individual entries override single keys.
    base_returnn_args = {
        "batch_size": 5000,
        "rasr_binary_path": RASR_BINARY_PATH,
        "rasr_loss_corpus_path": rasr_loss_corpus_path,
        "rasr_loss_corpus_segments": rasr_loss_corpus_segments,
        "rasr_loss_lexicon_path": rasr_loss_lexicon_path,
        "datasets": returnn_datasets,
        "extra_args": {
            "accum_grad_multiple_step": 2,
            "conv_pad_seq_len_to_power": 1.5,
        },
        "conformer_type": "wei",
    }

    # The two feature front-ends that are compared.
    scf_features = {
        "class": "ScfNetwork",
        "size_tf": 256 // 2,
        "stride_tf": 10 // 2,
        "preemphasis": 0.97,
    }
    lgm_features = {
        "class": "LogMelNetwork",
        "wave_norm": True,
        "frame_size": 200,
        "frame_shift": 80,
        "fft_size": 256,
    }

    # Learning rate schedule, identical for every experiment.
    lr_schedule = {
        "peak_lr": 4e-4,
        "start_lr": 1.325e-05,
        "end_lr": 1e-5,
        "increase_epochs": 180,
        "decrease_epochs": 180,
        "final_epochs": 0,
    }

    # Epochs used both for evaluation during setup and for recognition.
    epochs = (24, 350, 390, 400, 410, 420, 430, 440, 450)

    def stft_20ms(**masking):
        # Fresh specaug_stft dict for the "stft20ms" variants: the given masking
        # options followed by the explicit STFT parameters.
        return {**masking, "frame_size": 400, "frame_shift": 160, "fft_size": 512}

    def experiment(specaug_stft, feature_args, batch_label, **returnn_overrides):
        # One entry of nn_base_args. The overrides are applied after specaug_stft,
        # and every call creates its own returnn_args / report_args dicts.
        return dict(
            returnn_args={
                **base_returnn_args,
                "specaug_stft": specaug_stft,
                **returnn_overrides,
            },
            feature_args=feature_args,
            lr_args=lr_schedule,
            report_args={"batch_size": batch_label},
        )

    nn_base_args = {
        # Batch size 5k with gradient accumulation over 2 steps, SCF features.
        "bs2x5k_scf_stft20ms_time_only": experiment(
            stft_20ms(max_feature=0, max_feature_num=0), scf_features, "2x5k"
        ),
        "bs2x5k_scf_stft20ms_fmask_1_1of512": experiment(
            stft_20ms(max_feature=1, max_feature_num=1), scf_features, "2x5k"
        ),
        "bs2x5k_scf_stft20ms_fmask_2_4of512": experiment(
            stft_20ms(max_feature=4, max_feature_num=2), scf_features, "2x5k"
        ),
        "bs2x5k_scf_stft20ms_fmask_5_8of512": experiment(
            stft_20ms(max_feature=8), scf_features, "2x5k"
        ),
        "bs2x5k_scf_stft20ms_fmask_5_15of512": experiment(
            stft_20ms(max_feature=15), scf_features, "2x5k"
        ),
        # Log-mel features; these entries replace extra_args completely.
        "bs2x5k_lgm_stft20ms_fmask_5_8of512": experiment(
            stft_20ms(max_feature=8),
            lgm_features,
            "2x5k",
            extra_args={"accum_grad_multiple_step": 2},
        ),
        "bs10k_lgm_stft20ms_fmask_5_8of512": experiment(
            stft_20ms(max_feature=8),
            lgm_features,
            "10k",
            batch_size=10000,
            extra_args={},
        ),
        # Only max_feature is passed here; STFT parameters are left to the defaults.
        "bs10k_scf_stft10ms_fmask_5_8of256": experiment(
            {"max_feature": 8},
            scf_features,
            "10k",
            batch_size=10000,
            extra_args={"conv_pad_seq_len_to_power": 1.5},
        ),
        "bs10k_lgm_stft10ms_fmask_5_8of256": experiment(
            {"max_feature": 8},
            lgm_features,
            "10k",
            batch_size=10000,
            extra_args={},
        ),
        "bs10k_scf_stft10ms_fmask_5_4of256": experiment(
            {"max_feature": 4},
            scf_features,
            "10k",
            batch_size=10000,
            extra_args={"conv_pad_seq_len_to_power": 1.5},
        ),
        "bs10k_lgm_stft10ms_fmask_5_4of256": experiment(
            {"max_feature": 4},
            lgm_features,
            "10k",
            batch_size=10000,
            extra_args={},
        ),
    }

    nn_args, report_args_collection = get_nn_args_baseline(
        nn_base_args=nn_base_args,
        num_epochs=450,
        evaluation_epochs=list(epochs),
        prefix="conformer_",
    )

    # Pinned RETURNN revision; the hash is overwritten with a fixed label.
    returnn_root = CloneGitRepositoryJob(
        "https://github.com/rwth-i6/returnn",
        commit="c4d36d06f6465e82a50d400d114259e07b8b0709",
    ).out_repository
    returnn_root.hash_overwrite = "returnn_conv_padding"

    report, ctc_nn_system = run_nn_args(
        nn_args,
        report_args_collection,
        dev_corpora,
        "report_specaug_stft.csv",
        returnn_root=returnn_root,
        recog_args={"epochs": list(epochs)},
    )
    return report, ctc_nn_system


def py():
"""
called if the file is passed to sis manager, used to run all experiments (replacement for main)
Expand All @@ -1131,6 +1329,7 @@ def py():
report_scf_specaug_sort = run_scf_specaug_sort()
report_scf_audio_perturbation_from_checkpoint = run_scf_audio_perturbation_from_checkpoint()
report_mel_audio_perturbation_from_checkpoint = run_mel_audio_perturbation_from_checkpoint()
report_specaug_stft = run_specaug_stft_experiments()

report_base = Report(
columns_start=["train_name", "batch_size"],
Expand All @@ -1144,6 +1343,7 @@ def py():
report_scf_specaug_sort,
report_scf_audio_perturbation_from_checkpoint,
report_mel_audio_perturbation_from_checkpoint,
report_specaug_stft,
]
)
tk.register_report(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .network_helpers.specaug import add_specaug_layer, add_specaug_layer_v2
from .network_helpers.specaug_configurable import add_specaug_layer as add_specaug_layer_configurable
from .network_helpers.specaug_sort_layer2 import add_specaug_layer as add_specaug_layer_sort_layer2
from .network_helpers.specaug_stft import add_specaug_layer as add_specaug_layer_stft
from .network_helpers.conformer_wei import add_conformer_stack as add_conformer_stack_wei
from .network_helpers.conformer_wei import add_vgg_stack as add_vgg_stack_wei

Expand Down Expand Up @@ -150,6 +151,7 @@ def make_conformer_fullsum_ctc_model(
conformer_type: str = "wei",
specaug_old: Optional[Dict[str, Any]] = None,
specaug_config: Optional[Dict[str, Any]] = None,
specaug_stft: Optional[Dict[str, Any]] = None,
recognition: bool = False,
num_epochs: Optional[int] = None,
) -> Tuple[Dict, Union[str, List[str]]]:
Expand All @@ -159,7 +161,41 @@ def make_conformer_fullsum_ctc_model(
if recognition:
python_code = []
else:
if specaug_old is not None:
if specaug_stft is not None:
frame_size = specaug_stft.pop("frame_size", 200)
frame_shift = specaug_stft.pop("frame_shift", 80)
fft_size = specaug_stft.pop("fft_size", 256)

specaug_stft_args = {
"max_time_num": 1,
"max_time": 15,
"max_feature_num": 5,
"max_feature": 4,
**specaug_stft,
}

# Add STFT layer
network["stft"] = {
"class": "stft",
"from": ["data"],
"frame_size": frame_size,
"frame_shift": frame_shift,
"fft_size": fft_size,
}
from_list = ["stft"]

from_list, python_code = add_specaug_layer_stft(network, from_list=from_list, **specaug_stft_args)

# Add iSTFT layer
network["istft"] = {
"class": "istft",
"from": from_list,
"frame_size": frame_size,
"frame_shift": frame_shift,
"fft_size": fft_size,
}

elif specaug_old is not None:
assert specaug_config is None
sort_layer2 = specaug_old.pop("sort_layer2", False)
specaug_func = add_specaug_layer_sort_layer2 if sort_layer2 else add_specaug_layer
Expand All @@ -173,7 +209,9 @@ def make_conformer_fullsum_ctc_model(
from_list, python_code = specaug_func(network, from_list=from_list, **specaug_old_args)
elif specaug_config is not None:
assert specaug_old is None
from_list, python_code = add_specaug_layer_configurable(network, from_list=from_list, num_epochs=num_epochs, config=specaug_config)
from_list, python_code = add_specaug_layer_configurable(
network, from_list=from_list, num_epochs=num_epochs, config=specaug_config
)
else:
from_list, python_code = add_specaug_layer_v2(network, from_list=from_list)

Expand Down
Loading

0 comments on commit d7796df

Please sign in to comment.