diff --git a/users/vieting/experiments/switchboard/ctc/feat/baseline_args.py b/users/vieting/experiments/switchboard/ctc/feat/baseline_args.py
index 1fb03e574..9fe913744 100644
--- a/users/vieting/experiments/switchboard/ctc/feat/baseline_args.py
+++ b/users/vieting/experiments/switchboard/ctc/feat/baseline_args.py
@@ -114,6 +114,7 @@ def get_nn_args_single(
         if layer_config.get("class") != "variable" and layer_config.get("from", "data") == "data":
             feature_net["subnetwork"][layer]["from"] = source_layer
+
     returnn_config = get_returnn_config(
         num_inputs=1,
         num_outputs=num_outputs,
@@ -160,6 +161,7 @@ def get_returnn_config(
     conformer_type: str = "wei",
     specaug_old: Optional[Dict[str, Any]] = None,
     specaug_config: Optional[Dict[str, Any]] = None,
+    specaug_stft: Optional[Dict[str, Any]] = None,
     am_args: Optional[Dict[str, Any]] = None,
     batch_size: Union[int, Dict[str, int]] = 10000,
     sample_rate: int = 8000,
@@ -218,13 +220,17 @@ def get_returnn_config(
         conformer_type=conformer_type,
         specaug_old=specaug_old,
         specaug_config=specaug_config,
+        specaug_stft=specaug_stft,
         recognition=recognition,
         num_epochs=num_epochs,
     )
+    feature_net = copy.deepcopy(feature_net)
     if audio_perturbation:
         prolog += get_code_for_perturbation()
     for layer in list(network.keys()):
+        if layer in ("stft",):
+            continue
         if network[layer]["from"] == "data":
             network[layer]["from"] = "features"
         elif isinstance(network[layer]["from"], list) and "data" in network[layer]["from"]:
@@ -237,14 +243,13 @@ def get_returnn_config(
             network.pop(layer)
         network["source"] = {"class": "copy", "from": "features"}
     else:
-        # network["source"] = specaug_layer_jingjing(in_layer=["features"])
-        pass
+        if specaug_stft is not None:
+            feature_net["from"] = "istft"
 
     if audio_perturbation and recognition:
         # Remove pre-processing from recognition and replace with layers in the network if needed
         datasets["train"]["dataset"]["audio"].pop("pre_process", None)
-        feature_net = copy.deepcopy(feature_net)
         audio_perturb_args = 
extra_args.get("audio_perturb_args", {}) source_layer = "data" if "preemphasis" in audio_perturb_args: diff --git a/users/vieting/experiments/switchboard/ctc/feat/experiments.py b/users/vieting/experiments/switchboard/ctc/feat/experiments.py index c0f04a15b..c91c4322e 100644 --- a/users/vieting/experiments/switchboard/ctc/feat/experiments.py +++ b/users/vieting/experiments/switchboard/ctc/feat/experiments.py @@ -1122,6 +1122,204 @@ def run_mel_audio_perturbation_from_checkpoint(): return report +def run_specaug_stft_experiments(): + gs.ALIAS_AND_OUTPUT_SUBDIR = "experiments/switchboard/ctc/feat/" + + ( + returnn_datasets, + rasr_loss_corpus_path, + rasr_loss_corpus_segments, + rasr_loss_lexicon_path, + dev_corpora, + ) = get_datasets() + returnn_args = { + "batch_size": 5000, + "rasr_binary_path": RASR_BINARY_PATH, + "rasr_loss_corpus_path": rasr_loss_corpus_path, + "rasr_loss_corpus_segments": rasr_loss_corpus_segments, + "rasr_loss_lexicon_path": rasr_loss_lexicon_path, + "datasets": returnn_datasets, + "extra_args": { + "accum_grad_multiple_step": 2, + "conv_pad_seq_len_to_power": 1.5, + }, + "conformer_type": "wei", + } + feature_args_scf = {"class": "ScfNetwork", "size_tf": 256 // 2, "stride_tf": 10 // 2, "preemphasis": 0.97} + feature_args_lgm = { + "class": "LogMelNetwork", + "wave_norm": True, + "frame_size": 200, + "frame_shift": 80, + "fft_size": 256, + } + lr_args = { + "peak_lr": 4e-4, + "start_lr": 1.325e-05, + "end_lr": 1e-5, + "increase_epochs": 180, + "decrease_epochs": 180, + "final_epochs": 0, + } + + nn_args, report_args_collection = get_nn_args_baseline( + nn_base_args={ + "bs2x5k_scf_stft20ms_time_only": dict( + returnn_args={ + **returnn_args, + "specaug_stft": { + "max_feature": 0, + "max_feature_num": 0, + "frame_size": 400, + "frame_shift": 160, + "fft_size": 512, + }, + }, + feature_args=feature_args_scf, + lr_args=lr_args, + report_args={"batch_size": "2x5k"}, + ), + "bs2x5k_scf_stft20ms_fmask_1_1of512": dict( + returnn_args={ + 
**returnn_args, + "specaug_stft": { + "max_feature": 1, + "max_feature_num": 1, + "frame_size": 400, + "frame_shift": 160, + "fft_size": 512, + }, + }, + feature_args=feature_args_scf, + lr_args=lr_args, + report_args={"batch_size": "2x5k"}, + ), + "bs2x5k_scf_stft20ms_fmask_2_4of512": dict( + returnn_args={ + **returnn_args, + "specaug_stft": { + "max_feature": 4, + "max_feature_num": 2, + "frame_size": 400, + "frame_shift": 160, + "fft_size": 512, + }, + }, + feature_args=feature_args_scf, + lr_args=lr_args, + report_args={"batch_size": "2x5k"}, + ), + "bs2x5k_scf_stft20ms_fmask_5_8of512": dict( + returnn_args={ + **returnn_args, + "specaug_stft": {"max_feature": 8, "frame_size": 400, "frame_shift": 160, "fft_size": 512}, + }, + feature_args=feature_args_scf, + lr_args=lr_args, + report_args={"batch_size": "2x5k"}, + ), + "bs2x5k_scf_stft20ms_fmask_5_15of512": dict( + returnn_args={ + **returnn_args, + "specaug_stft": {"max_feature": 15, "frame_size": 400, "frame_shift": 160, "fft_size": 512}, + }, + feature_args=feature_args_scf, + lr_args=lr_args, + report_args={"batch_size": "2x5k"}, + ), + "bs2x5k_lgm_stft20ms_fmask_5_8of512": dict( + returnn_args={ + **returnn_args, + "specaug_stft": {"max_feature": 8, "frame_size": 400, "frame_shift": 160, "fft_size": 512}, + "extra_args": {"accum_grad_multiple_step": 2}, + }, + feature_args=feature_args_lgm, + lr_args=lr_args, + report_args={ + "batch_size": "2x5k", + }, + ), + "bs10k_lgm_stft20ms_fmask_5_8of512": dict( + returnn_args={ + **returnn_args, + "specaug_stft": {"max_feature": 8, "frame_size": 400, "frame_shift": 160, "fft_size": 512}, + "batch_size": 10000, + "extra_args": {}, + }, + feature_args=feature_args_lgm, + lr_args=lr_args, + report_args={"batch_size": "10k"}, + ), + "bs10k_scf_stft10ms_fmask_5_8of256": dict( + returnn_args={ + **returnn_args, + "specaug_stft": {"max_feature": 8}, + "batch_size": 10000, + "extra_args": { + "conv_pad_seq_len_to_power": 1.5, + }, + }, + feature_args=feature_args_scf, + 
                lr_args=lr_args,
+                report_args={"batch_size": "10k"},
+            ),
+            "bs10k_lgm_stft10ms_fmask_5_8of256": dict(
+                returnn_args={
+                    **returnn_args,
+                    "specaug_stft": {"max_feature": 8},
+                    "batch_size": 10000,
+                    "extra_args": {},
+                },
+                feature_args=feature_args_lgm,
+                lr_args=lr_args,
+                report_args={"batch_size": "10k"},
+            ),
+            "bs10k_scf_stft10ms_fmask_5_4of256": dict(
+                returnn_args={
+                    **returnn_args,
+                    "specaug_stft": {"max_feature": 4},
+                    "batch_size": 10000,
+                    "extra_args": {
+                        "conv_pad_seq_len_to_power": 1.5,
+                    },
+                },
+                feature_args=feature_args_scf,
+                lr_args=lr_args,
+                report_args={"batch_size": "10k"},
+            ),
+            "bs10k_lgm_stft10ms_fmask_5_4of256": dict(
+                returnn_args={
+                    **returnn_args,
+                    "specaug_stft": {"max_feature": 4},
+                    "batch_size": 10000,
+                    "extra_args": {},
+                },
+                feature_args=feature_args_lgm,
+                lr_args=lr_args,
+                report_args={"batch_size": "10k"},
+            ),
+        },
+        num_epochs=450,
+        evaluation_epochs=[24, 350, 390, 400, 410, 420, 430, 440, 450],
+        prefix="conformer_",
+    )
+
+    returnn_root = CloneGitRepositoryJob(
+        "https://github.com/rwth-i6/returnn",
+        commit="c4d36d06f6465e82a50d400d114259e07b8b0709",
+    ).out_repository
+    returnn_root.hash_overwrite = "returnn_conv_padding"
+    report, ctc_nn_system = run_nn_args(
+        nn_args,
+        report_args_collection,
+        dev_corpora,
+        "report_specaug_stft.csv",
+        returnn_root=returnn_root,
+        recog_args={"epochs": [24, 350, 390, 400, 410, 420, 430, 440, 450]},
+    )
+    return report, ctc_nn_system
+
+
 def py():
     """
     called if the file is passed to sis manager, used to run all experiments (replacement for main)
@@ -1131,6 +1329,7 @@ def py():
     report_scf_specaug_sort = run_scf_specaug_sort()
     report_scf_audio_perturbation_from_checkpoint = run_scf_audio_perturbation_from_checkpoint()
     report_mel_audio_perturbation_from_checkpoint = run_mel_audio_perturbation_from_checkpoint()
+    report_specaug_stft, _ = run_specaug_stft_experiments()
 
     report_base = Report(
         columns_start=["train_name", "batch_size"],
@@ -1144,6 +1343,7 @@
report_scf_specaug_sort, report_scf_audio_perturbation_from_checkpoint, report_mel_audio_perturbation_from_checkpoint, + report_specaug_stft, ] ) tk.register_report( diff --git a/users/vieting/experiments/switchboard/ctc/feat/fullsum_ctc_raw_samples.py b/users/vieting/experiments/switchboard/ctc/feat/fullsum_ctc_raw_samples.py index a81890005..21ac2e254 100644 --- a/users/vieting/experiments/switchboard/ctc/feat/fullsum_ctc_raw_samples.py +++ b/users/vieting/experiments/switchboard/ctc/feat/fullsum_ctc_raw_samples.py @@ -9,6 +9,7 @@ from .network_helpers.specaug import add_specaug_layer, add_specaug_layer_v2 from .network_helpers.specaug_configurable import add_specaug_layer as add_specaug_layer_configurable from .network_helpers.specaug_sort_layer2 import add_specaug_layer as add_specaug_layer_sort_layer2 +from .network_helpers.specaug_stft import add_specaug_layer as add_specaug_layer_stft from .network_helpers.conformer_wei import add_conformer_stack as add_conformer_stack_wei from .network_helpers.conformer_wei import add_vgg_stack as add_vgg_stack_wei @@ -150,6 +151,7 @@ def make_conformer_fullsum_ctc_model( conformer_type: str = "wei", specaug_old: Optional[Dict[str, Any]] = None, specaug_config: Optional[Dict[str, Any]] = None, + specaug_stft: Optional[Dict[str, Any]] = None, recognition: bool = False, num_epochs: Optional[int] = None, ) -> Tuple[Dict, Union[str, List[str]]]: @@ -159,7 +161,41 @@ def make_conformer_fullsum_ctc_model( if recognition: python_code = [] else: - if specaug_old is not None: + if specaug_stft is not None: + frame_size = specaug_stft.pop("frame_size", 200) + frame_shift = specaug_stft.pop("frame_shift", 80) + fft_size = specaug_stft.pop("fft_size", 256) + + specaug_stft_args = { + "max_time_num": 1, + "max_time": 15, + "max_feature_num": 5, + "max_feature": 4, + **specaug_stft, + } + + # Add STFT layer + network["stft"] = { + "class": "stft", + "from": ["data"], + "frame_size": frame_size, + "frame_shift": frame_shift, + "fft_size": 
fft_size, + } + from_list = ["stft"] + + from_list, python_code = add_specaug_layer_stft(network, from_list=from_list, **specaug_stft_args) + + # Add iSTFT layer + network["istft"] = { + "class": "istft", + "from": from_list, + "frame_size": frame_size, + "frame_shift": frame_shift, + "fft_size": fft_size, + } + + elif specaug_old is not None: assert specaug_config is None sort_layer2 = specaug_old.pop("sort_layer2", False) specaug_func = add_specaug_layer_sort_layer2 if sort_layer2 else add_specaug_layer @@ -173,7 +209,9 @@ def make_conformer_fullsum_ctc_model( from_list, python_code = specaug_func(network, from_list=from_list, **specaug_old_args) elif specaug_config is not None: assert specaug_old is None - from_list, python_code = add_specaug_layer_configurable(network, from_list=from_list, num_epochs=num_epochs, config=specaug_config) + from_list, python_code = add_specaug_layer_configurable( + network, from_list=from_list, num_epochs=num_epochs, config=specaug_config + ) else: from_list, python_code = add_specaug_layer_v2(network, from_list=from_list) diff --git a/users/vieting/experiments/switchboard/ctc/feat/network_helpers/specaug_stft.py b/users/vieting/experiments/switchboard/ctc/feat/network_helpers/specaug_stft.py new file mode 100644 index 000000000..915b48537 --- /dev/null +++ b/users/vieting/experiments/switchboard/ctc/feat/network_helpers/specaug_stft.py @@ -0,0 +1,154 @@ +from typing import Dict, Union, Optional, List +from i6_core.returnn.config import CodeWrapper +from .rc_specaug import _mask_v1, random_mask_v1, specaugment_v1_eval_func + + +def add_specaug_layer( + network: Dict, + name: str = "specaug", + from_list: Optional[Union[str, List[str]]] = None, + max_time_num: int = 3, + max_time: int = 10, + max_feature_num: int = 4, + max_feature: int = 5, +) -> List[str]: + if from_list is None: + from_list = ["data"] + network[name] = { + "class": "eval", + "from": from_list, + "eval": 
f'self.network.get_config().typed_value("transform")(source(0, as_data=True), max_time_num={max_time_num}, max_time={max_time}, max_feature_num={max_feature_num}, max_feature={max_feature}, network=self.network)', + } + + return [name], get_specaug_funcs() + + +def _mask(x, batch_axis, axis, pos, max_amount): + """ + :param tf.Tensor x: (batch,time,feature) + :param int batch_axis: + :param int axis: + :param tf.Tensor pos: (batch,) + :param int|tf.Tensor max_amount: inclusive + """ + import tensorflow as tf + + ndim = x.get_shape().ndims + n_batch = tf.shape(x)[batch_axis] + dim = tf.shape(x)[axis] + amount = tf.random.uniform(shape=(n_batch,), minval=1, maxval=max_amount + 1, dtype=tf.int32) + pos2 = tf.math.minimum(pos + amount, dim) + idxs = tf.expand_dims(tf.range(0, dim), 0) # (1,dim) + pos_bc = tf.expand_dims(pos, 1) # (batch,1) + pos2_bc = tf.expand_dims(pos2, 1) # (batch,1) + cond = tf.math.logical_and(tf.greater_equal(idxs, pos_bc), tf.less(idxs, pos2_bc)) # (batch,dim) + if batch_axis > axis: + cond = tf.transpose(cond) # (dim,batch) + cond = tf.reshape(cond, [tf.shape(x)[i] if i in (batch_axis, axis) else 1 for i in range(ndim)]) + from TFUtil import where_bc + + x = where_bc(cond, tf.constant(0.0, dtype=x.dtype), x) + return x + + +def random_mask(x, batch_axis, axis, min_num, max_num, max_dims): + """ + :param tf.Tensor x: (batch,time,feature) + :param int batch_axis: + :param int axis: + :param int|tf.Tensor min_num: + :param int|tf.Tensor max_num: inclusive + :param int|tf.Tensor max_dims: inclusive + """ + import tensorflow as tf + + n_batch = tf.shape(x)[batch_axis] + if isinstance(min_num, int) and isinstance(max_num, int) and min_num == max_num: + num = min_num + else: + num = tf.random.uniform(shape=(n_batch,), minval=min_num, maxval=max_num + 1, dtype=tf.int32) + # https://github.com/tensorflow/tensorflow/issues/9260 + # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/ + z = 
-tf.math.log(-tf.math.log(tf.random.uniform((n_batch, tf.shape(x)[axis]), 0, 1))) + _, indices = tf.math.top_k(z, num if isinstance(num, int) else tf.reduce_max(num)) + # indices should be sorted, and of shape (batch,num), entries (int32) in [0,dim) + # indices = tf.Print(indices, ["indices", indices, tf.shape(indices)]) + if isinstance(num, int): + for i in range(num): + x = _mask( + x, + batch_axis=batch_axis, + axis=axis, + pos=indices[:, i], + max_amount=max_dims, + ) + else: + _, x = tf.while_loop( + cond=lambda i, _: tf.less(i, tf.reduce_max(num)), + body=lambda i, x: ( + i + 1, + tf.where( + tf.expand_dims(tf.expand_dims(tf.less(i, num), axis=-1), axis=-1), + _mask( + x, + batch_axis=batch_axis, + axis=axis, + pos=indices[:, i], + max_amount=max_dims, + ), + x, + ), + ), + loop_vars=(0, x), + ) + return x + + +def transform(data, max_time_num, max_time, max_feature_num, max_feature, network): + # halved before this step + conservative_step = 2000 + + x = data.placeholder + import tensorflow as tf + + step = network.global_train_step + increase_flag = tf.where(tf.greater_equal(step, conservative_step), 0, 1) + + def get_masked(): + x_masked = x + x_masked = random_mask( + x_masked, + batch_axis=data.batch_dim_axis, + axis=data.time_dim_axis, + min_num=0, + max_num=tf.maximum( + tf.shape(x)[data.time_dim_axis] // int(1.0 / 0.7 * max_time), + max_time_num, + ) + // (1 + increase_flag), + max_dims=max_time, + ) + x_masked = random_mask( + x_masked, + batch_axis=data.batch_dim_axis, + axis=data.feature_dim_axis, + min_num=0, + max_num=max_feature_num // (1 + increase_flag), + max_dims=max_feature, + ) + return x_masked + + x = network.cond_on_train(get_masked, lambda: x) + return x + + +def get_specaug_funcs() -> list: + return [_mask, random_mask, transform] + + +def get_specaug_func_v2() -> list: + return [ + _mask_v1, + random_mask_v1, + specaugment_v1_eval_func, + ]