Commit

update
robin-p-schmitt committed Mar 6, 2025
1 parent c79d713 commit f4aa1f2
Showing 10 changed files with 164 additions and 34 deletions.
@@ -139,7 +139,7 @@ def forward_sequence(
config = get_global_config()

if type(model) is TransformerDecoder:
logits, _, _ = model(
logits, _ = model(
rf.shift_right(targets, axis=targets_spatial_dim, pad_value=0),
spatial_dim=targets_spatial_dim,
encoder=model.transform_encoder(att_enc_args["enc"], axis=enc_spatial_dim),
@@ -290,6 +290,17 @@ def run_exps():
corpus_keys=("test-other",)
)

if alias == "v1_long_two-stage":
checkpoint_aliases = ("best-4-avg",)

recog.center_window_returnn_frame_wise_beam_search(
alias=full_sum_train_alias,
config_builder=config_builder,
checkpoint=full_sum_checkpoint,
checkpoint_aliases=checkpoint_aliases,
beam_size_list=(100,)
)

if use_sep_h_t_readout and "long" in alias:
for att_readout_scale, h_t_readout_scale in [
(1.0, 0.1),
@@ -153,6 +153,7 @@ def run_exps():
checkpoint=fixed_path_checkpoint,
checkpoint_aliases=checkpoint_aliases,
)

if win_size is None and "long" in alias:
pipeline = recog.center_window_returnn_frame_wise_beam_search(
alias=fixed_path_train_alias,
@@ -269,6 +270,17 @@ def run_exps():
checkpoint=full_sum_checkpoint,
)

if alias == "v1_long_two-stage":
checkpoint_aliases = ("best-4-avg",)

recog.center_window_returnn_frame_wise_beam_search(
alias=full_sum_train_alias,
config_builder=config_builder,
checkpoint=full_sum_checkpoint,
checkpoint_aliases=checkpoint_aliases,
beam_size_list=(100,)
)

if alias == "v1_long_two-stage" and gpu_mem_rqmt == 24:
recog.center_window_returnn_frame_wise_beam_search(
alias=full_sum_train_alias,
@@ -11,14 +11,34 @@
def run_exps():
for model_alias, config_builder in baseline.global_att_baseline_rf(use_weight_feedback=True):
# v5: same as v3, but use bpe size 10k
for random_seed in [None, 1234]:
for random_seed, gpu_mem_rqmt, ctc_aux_loss_layers in [
[None, 11, None],
[1234, 11, None],
[None, 24, (4, 8)],
]:

if gpu_mem_rqmt == 24:
use_mgpu = False
accum_grad_multiple_step = 2
batch_size = 35_000
n_epochs = 2_000
else:
use_mgpu = True
accum_grad_multiple_step = 4
batch_size = 15_000
n_epochs = 500

for train_alias, checkpoint in train.train_global_att(
alias=model_alias,
config_builder=config_builder,
n_epochs=500,
keep_epochs=[10, 20, 30] + list(range(30, 50, 1)),
n_epochs=n_epochs,
filter_data_len=19.5 * 16_000, # sample rate 16kHz
random_seed=random_seed,
use_mgpu=use_mgpu,
accum_grad_multiple_step=accum_grad_multiple_step,
ctc_aux_loss_layers=ctc_aux_loss_layers,
batch_size=batch_size,
gpu_mem_rqmt=gpu_mem_rqmt,
):
if random_seed != 1234:
recog.global_att_returnn_label_sync_beam_search(
@@ -155,6 +175,14 @@ def run_exps():
run_analysis=True,
analyze_gradients=True,
)
recog.global_att_returnn_label_sync_beam_search(
alias=train_alias,
config_builder=config_builder,
checkpoint=checkpoint,
checkpoint_aliases=("last",),
beam_size_list=(100,),
corpus_keys=("test-other", "dev-other"),
)
recog.global_att_returnn_label_sync_beam_search(
alias=train_alias,
config_builder=config_builder,
@@ -293,6 +293,16 @@ def run_exps():
behavior_version=21,
)

if alias in ("v5_big", "v3_big"):
recog.global_att_returnn_label_sync_beam_search(
alias=train_alias,
config_builder=config_builder,
checkpoint=checkpoint,
corpus_keys=("dev-other", "test-other"),
checkpoint_aliases=("last",),
beam_size_list=(100,)
)

if alias in (
# "v5_big",
"v8_big",
@@ -50,19 +50,21 @@ def run_exps():
["v3_wo-wf-w-ctx-in-state_ctc", None, None, (4, 8), False, 12, 512, False, False, False, list(range(1, 240)), 24, None, None, False, False, False, False, False, 1e-6, None],
["v3_wo-wf-wo-ctx-in-state_ctc", None, None, (4, 8), False, 12, 512, False, False, False, list(range(1, 240)), 24, None, None, False, False, False, False, False, 1e-6, None],
["v3_w-wf-w-ctx-in-state_ctc", 9999, None, (4, 8), False, 12, 512, False, False, False, list(range(1, 240)), 24, None, None, False, False, False, False, False, 1e-6, None],
["v3_rand-9999", 9999, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand - flipped
["v3_rand-1234", 1234, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand -
["v3_rand-1111", 1111, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_rand-4321", 4321, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_rand-5678", 5678, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
# ["v3_rand-9999", 9999, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand - flipped
# ["v3_rand-1234", 1234, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand -
# ["v3_rand-1111", 1111, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
# ["v3_rand-4321", 4321, None, None, False, 12, 512, False, False, False, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
# ["v3_rand-5678", 5678, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_big_rand-5678", 5678, None, None, False, 12, 512, False, False, False, list(range(20, 200, 20)), 24, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_rand-8765", 8765, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_rand-2222", 2222, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
# ["v3_rand-8765", 8765, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
# ["v3_rand-2222", 2222, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_big_rand-2222", 2222, None, None, False, 12, 512, False, False, False, list(range(20, 200, 20)), 24, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v3_rand-3333", 3333, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
# ["v3_rand-3333", 3333, None, None, False, 12, 512, False, False, False, list(range(10, 80, 10)), 11, None, None, False, False, False, False, False, 1e-6, None], # v3_big_rand
["v5", None, 21, None, False, 12, 512, False, False, False, list(range(61)), 11, None, None, False, False, False, False, False, 1e-6, None], # v5_big: same as v3_big, but enable self attention only after 20 sub-epochs (1 full epoch)
["v5_rand-1234", 1234, 21, None, False, 12, 512, False, False, False, list(range(61)), 11, None, None, False, False, False, False, False, 1e-6, None], # v5_big: same as v3_big, but enable self attention only after 20 sub-epochs (1 full epoch)
# ["v5_rand-1234", 1234, 21, None, False, 12, 512, False, False, False, list(range(61)), 11, None, None, False, False, False, False, False, 1e-6, None], # v5_big: same as v3_big, but enable self attention only after 20 sub-epochs (1 full epoch)
["v5_big_rand-1234", 1234, 21, None, False, 12, 512, False, False, False, list(range(20, 200, 20)), 24, None, None, False, False, False, False, False, 1e-6, None], # v5_big: same as v3_big, but enable self attention only after 20 sub-epochs (1 full epoch)
# ["v5_big_rand-2222", 2222, 21, None, False, 12, 512, False, False, False, [2000], 24, None, None, False, False, False, False, False, 1e-6, None], # v5_big: same as v3_big, but enable self attention only after 20 sub-epochs (1 full epoch)
# ["v5_big_rand-3333", 3333, 21, None, False, 12, 512, False, False, False, [2000], 24, None, None, False, False, False, False, False, 1e-6, None], # v5_big: same as v3_big, but enable self attention only after 20 sub-epochs (1 full epoch)
["v6_big", None, None, None, False, 12, 512, False, False, True, list(range(1, 240)), 24, None, None, False, False, False, False, False, 1e-6, None], # v6_big: same as v3_big, but use both absolute and relative positional encodings
["v6", None, None, None, False, 12, 512, False, False, True, list(range(1, 240)), 11, None, None, False, False, False, False, False, 1e-6, None], # v6_big: same as v3_big, but use both absolute and relative positional encodings
# ["v7_big", None, None, None, True, 12, 512, False, False, False, [121, 131, 141], 24, None, None, False], # v7_big: same as v3_big, but do not use final layer norm in conformer encoder layers
@@ -235,6 +237,16 @@ def run_exps():
sbatch_args=["-p", "gpu_48gb,gpu_24gb_preemptive"],
)

if alias in ("v5_big_rand-1234", "v20"):
recog.global_att_returnn_label_sync_beam_search(
alias=train_alias,
config_builder=config_builder,
checkpoint=checkpoint,
corpus_keys=("dev-other", "test-other"),
checkpoint_aliases=("last",),
beam_size_list=(100,)
)

# test different const LRs for mini LSTM
if alias in ["v5_big_rand-1234"]:
for lm_scale, ilm_scale in [
@@ -440,6 +452,7 @@ def run_exps():

plot_flipped_cross_att_weight_evolution_v2(flipped_att_weights_evolution_epochs, flipped_att_weights_evolution)
# plot_flipped_self_att_weight_evolution()
plot_flipped_self_att_energies_single_epochs()
# plot_flipped_vs_normal_cross_att_weights()
# plot_gradients_wrt_different_layers()

@@ -623,14 +636,31 @@ def plot_gradients_wrt_different_layers():


def plot_flipped_self_att_weight_evolution():
epochs = [10, 20, 30, 32, 34, 38, 40, 50]
epochs = [10, 40, 50]
for head in range(8, 9):
plot_self_att_weights_job = PlotSelfAttentionWeightsOverEpochsJob(
att_weight_hdfs=[
Path(
f"/u/schmitt/experiments/03-09-24_aed_flipped_encoder/alias/models/ls_conformer/global_att/baseline_v1/baseline_rf/bpe1056/w-weight-feedback/w-att-ctx-in-state/nb-lstm/12-layer_512-dim_conformer-w-abs-pos/train_from_scratch/500-ep_bs-15000_mgpu-4_w-sp_curric_lr-dyn_lr_piecewise_linear_epoch-wise_v2_reg-v1_filter-data-312000.0_accum-4/returnn_decoding/epoch-{epoch}-checkpoint/no-lm/beam-size-12/train/analysis/dump_self_att/ground-truth/output/self-att-energies_head-{head}.hdf") for epoch in epochs
f"/u/schmitt/experiments/03-09-24_aed_flipped_encoder/alias/models/ls_conformer/global_att/baseline_v1/baseline_rf/bpe1056/w-weight-feedback/w-att-ctx-in-state/nb-lstm/12-layer_512-dim_conformer-w-abs-pos/train_from_scratch/500-ep_bs-15000_mgpu-4_w-sp_curric_wd-1e-06_reg-v1_filter-data-312000.0_accum-4/lr_dyn_lr_piecewise_linear_epoch-wise_v2_peak_lr-0.001_init_lr-1e-05/returnn_decoding/epoch-{epoch}-checkpoint/scale-1.00_len-norm-exp-1.0/no-lm/wo_ilm_correction/beam-size-12/train/analysis/dump_self_att/ground-truth/output/self-att-energies_head-{head}.hdf") for epoch in epochs
],
epochs=epochs,
)
plot_self_att_weights_job.add_alias(f"flipped_self_att_evolution_head-{head}")
tk.register_output(plot_self_att_weights_job.get_one_alias(), plot_self_att_weights_job.out_plot_dir)


def plot_flipped_self_att_energies_single_epochs():
epochs = [10, 40, 50]
for epoch in epochs:
for head in range(8, 9):
plot_self_att_weights_job = PlotSelfAttentionWeightsOverEpochsJob(
att_weight_hdfs=[
Path(
f"/u/schmitt/experiments/03-09-24_aed_flipped_encoder/alias/models/ls_conformer/global_att/baseline_v1/baseline_rf/bpe1056/w-weight-feedback/w-att-ctx-in-state/nb-lstm/12-layer_512-dim_conformer-w-abs-pos/train_from_scratch/500-ep_bs-15000_mgpu-4_w-sp_curric_wd-1e-06_reg-v1_filter-data-312000.0_accum-4/lr_dyn_lr_piecewise_linear_epoch-wise_v2_peak_lr-0.001_init_lr-1e-05/returnn_decoding/epoch-{epoch}-checkpoint/scale-1.00_len-norm-exp-1.0/no-lm/wo_ilm_correction/beam-size-12/train/analysis/dump_self_att/ground-truth/output/self-att-energies_head-{head}.hdf")
],
epochs=[epoch],
)
plot_self_att_weights_job.add_alias(f"flipped_self_att_evolution_head-{head}_epoch-{epoch}")
tk.register_output(plot_self_att_weights_job.get_one_alias(), plot_self_att_weights_job.out_plot_dir)
@@ -21,6 +21,13 @@ def run_exps():
analyze_gradients=True,
run_analysis=True,
)
# recog.global_att_returnn_label_sync_beam_search(
# alias=train_alias,
# config_builder=config_builder,
# checkpoint=checkpoint,
# checkpoint_aliases=("epoch-498",),
# beam_size_list=(100,),
# )
for corpus_key in [
# "dev-other_0.1-5.1",
# "dev-other_5.1-10.1",
@@ -55,6 +62,26 @@
]
)

for model_alias, config_builder in baseline.global_att_baseline_rf(
use_weight_feedback=True,
label_type="bpe10025"
):
for train_alias, checkpoint in train.train_global_att(
alias=model_alias,
config_builder=config_builder,
n_epochs=2_000,
batch_size=35_000,
gpu_mem_rqmt=24,
accum_grad_multiple_step=2,
use_mgpu=False,
use_torch_amp=False,
filter_data_len=19.5 * 16_000,
random_seed=None,
ctc_aux_loss_layers=(4, 8),
):
recog.global_att_returnn_label_sync_beam_search(
alias=train_alias,
config_builder=config_builder,
checkpoint=checkpoint,
beam_size_list=(12, 100,),
)
@@ -38,7 +38,7 @@
with_pos_bias=False,
learnable_pos_emb=True,
separate_pos_emb_per_head=False,
pos_emb_dropout=0.0,
pos_emb_dropout=0.0, # 0.1
),
self_att=rf.build_dict(rf.RelPosSelfAttention),
ff_activation=rf.build_dict(rf.relu_square),
@@ -64,7 +64,7 @@
gradient_clip_global_norm=5.0,
optimizer={
"class": "adamw",
"epsilon": 1e-8,
"epsilon": 1e-8, # 1e-16
"weight_decay": 1e-6,
"weight_decay_modules_blacklist": [
"rf.Embedding",
@@ -257,10 +257,8 @@ def py():
recog_exp = RecogExperiment(
alias=config["alias"],
config_builder=config_builder,
# checkpoint=recog_checkpoints["best-wer"],
# checkpoint_alias="best-wer",
checkpoint=checkpoints[518],
checkpoint_alias=f"epoch-518",
checkpoint=recog_checkpoints["best-wer"],
checkpoint_alias="best-wer",
recog_opts=recog_opts,
search_rqmt=dict(cpu=4)
)
34 changes: 24 additions & 10 deletions users/schmitt/visualization/visualization.py
@@ -819,9 +819,9 @@ def run(self):
num_rows = 1
num_cols = num_plots

title_fontsize = 38
ticklabel_fontsize = 28
axlabel_fontsize = 34
title_fontsize = 30 # 38
ticklabel_fontsize = 24 # 28
axlabel_fontsize = 26 # 34

for seq_tag in att_weights_dicts[0]:
# fig, axes = plt.subplots(
@@ -837,16 +837,22 @@
epoch = self.epochs[i]
time_len = att_weights.shape[0]

ax = axes[row, col]
if num_rows == 2:
ax = axes[row, col]
elif num_plots == 1:
ax = axes
else:
ax = axes[col]
ax.matshow(att_weights, cmap=plt.cm.get_cmap("Blues"), aspect="auto")

ax.set_title(f"Epoch {epoch * 4 / 20}", fontsize=title_fontsize, pad=7)
if num_plots != 1:
ax.set_title(f"Epoch {epoch * 4 / 20}", fontsize=title_fontsize, pad=7)

time_step_size = 1 / 60 * 1000
time_ticks = np.arange(0, time_len, time_step_size)
tick_labels = [(time_tick * 60) / 1000 for time_tick in time_ticks]

if row == 0:
if row == 0 and num_rows == 2:
ax.set_xticks([])
else:
ax.set_xticks(time_ticks)
@@ -863,12 +869,20 @@

ax.invert_yaxis()

fig.text(0.5, 0.01, 'Keys/Values time (s)', ha='center', fontsize=axlabel_fontsize)
fig.text(0.06, 0.5, 'Queries time (s)', va='center', rotation='vertical', fontsize=axlabel_fontsize)
# fig.text(0.5, 0.01, 'Keys/Values time (s)', ha='center', fontsize=axlabel_fontsize)
# fig.text(0.06, 0.5, 'Queries time (s)', va='center', rotation='vertical', fontsize=axlabel_fontsize)
fig.text(0.5, -0.15, 'Keys/Values time (s)', ha='center', fontsize=axlabel_fontsize)
fig.text(-0.15, 0.5, 'Queries time (s)', va='center', rotation='vertical', fontsize=axlabel_fontsize)
# fig.tight_layout()

plt.savefig(os.path.join(self.out_plot_dir.get_path(), f"plot.{seq_tag.replace('/', '_')}.png"), bbox_inches='tight')
plt.savefig(os.path.join(self.out_plot_dir.get_path(), f"plot.{seq_tag.replace('/', '_')}.pdf"), bbox_inches='tight')
plt.savefig(
os.path.join(self.out_plot_dir.get_path(), f"plot.{seq_tag.replace('/', '_')}.png"),
bbox_inches='tight'
)
plt.savefig(
os.path.join(self.out_plot_dir.get_path(), f"plot.{seq_tag.replace('/', '_')}.pdf"),
bbox_inches='tight'
)
plt.close()

