Commit b930a71

decouple mhsa residual

1 parent 6987cd7

File tree

2 files changed: +13 -6 lines changed

users/zeineldeen/models/asr/encoder/conformer_encoder_v2.py

+8 -6

@@ -903,12 +903,7 @@ def _create_mhsa_module(self, prefix_name, source, layer_index):
 
         drop = self.network.add_dropout_layer("{}_dropout".format(prefix_name), mhsa_linear, dropout=self.dropout)
 
-        res_inputs = [drop, source]
-
-        mhsa_res = self.network.add_combine_layer(
-            "{}_res".format(prefix_name), kind="add", source=res_inputs, n_out=self.enc_value_dim
-        )
-        return mhsa_res
+        return drop
 
     def _create_convolution_module(self, prefix_name, source, layer_index, half_step=False):
         """
@@ -1071,6 +1066,10 @@ def _create_conformer_block(self, i, source):
         if self.convolution_first:
             conv_module_ = self._create_convolution_module(prefix_name, ff_module1, i)
             mhsa_module = self._create_mhsa_module(prefix_name, conv_module_, i)
+            mhsa_module = self.network.add_combine_layer(
+                "{}_res".format(prefix_name), kind="add", source=[mhsa_module, conv_module_], n_out=self.enc_value_dim
+            )
+
             ff_module2_input = mhsa_module
         else:
             if self.no_mhsa_module:
@@ -1083,6 +1082,9 @@ def _create_conformer_block(self, i, source):
             )
             mhsa_input = conv_module1
             mhsa = self._create_mhsa_module(prefix_name, mhsa_input, i)
+            mhsa = self.network.add_combine_layer(
+                "{}_res".format(prefix_name), kind="add", source=[mhsa, mhsa_input], n_out=self.enc_value_dim
+            )
 
         conv_module = self._create_convolution_module(prefix_name, mhsa, i, half_step=self.sandwich_conv)
         ff_module2_input = conv_module
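
What changed: _create_mhsa_module used to add its own residual connection (the "{}_res" combine layer) before returning; now it returns the raw branch output, and both call sites in _create_conformer_block add the residual themselves. Behavior at these call sites is unchanged (the residual input is still the module's own input), but the residual now lives at the block level, where it can be varied independently of the MHSA module. Below is a minimal PyTorch sketch of the same refactor, with simplified module internals (the real module also feeds a linear layer, mhsa_linear, into the dropout); names like MHSABranch are illustrative, not from the repo.

import torch
import torch.nn as nn

class MHSABranch(nn.Module):
    """Hypothetical mirror of _create_mhsa_module after this commit:
    the branch returns its raw (dropped-out) output and no longer
    adds the residual itself."""

    def __init__(self, dim: int, heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.ln(x)
        y, _ = self.attn(y, y, y, need_weights=False)
        return self.drop(y)  # before this commit: return x + self.drop(y)

# The caller now owns the residual (the "{}_res" combine layer, kind="add"),
# so the conformer block decides what the branch output is added to:
def apply_mhsa_with_residual(branch: MHSABranch, x: torch.Tensor) -> torch.Tensor:
    return x + branch(x)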

users/zeineldeen/models/asr/encoder/ebranchformer_encoder.py

+5

@@ -14,6 +14,8 @@ def __init__(self, cgmlp_ff_dim, **kwargs):
         self.cgmlp_ff_dim = cgmlp_ff_dim
 
     def _create_conv_spatial_gating_unit(self, prefix_name, source, layer_index):
+        # Half split input into [A,B] -> A * DwConv(LN(B)) -> dropout
+        #
         # see also here: https://github.com/espnet/espnet/blob/master/espnet2/asr/layers/cgmlp.py#L15
 
         split_size = self.cgmlp_ff_dim // 2
@@ -47,6 +49,8 @@ def _create_conv_spatial_gating_unit(self, prefix_name, source, layer_index):
         return dropout
 
     def _create_conv_gating_mlp(self, prefix_name, source, layer_index):
+        # GeLU(FF(LN(x))) -> Half split input into [A,B] -> A * DwConv(LN(B)) -> dropout -> FF
+
         prefix_name = "{}_cgmlp".format(prefix_name)
 
         ln = self.network.add_layer_norm_layer("{}_ln".format(prefix_name), source)
@@ -65,6 +69,7 @@ def _create_conv_gating_mlp(self, prefix_name, source, layer_index):
 
         gelu_act = self.network.add_activation_layer("{}_gelu".format(prefix_name), ff1, activation="gelu")
 
+        # Half split input into [A,B] -> A * DwConv(LN(B)) -> dropout
        csgu = self._create_conv_spatial_gating_unit(f"{prefix_name}_csgu", gelu_act, layer_index)
 
         br_merge_ff = self.network.add_linear_layer(
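
The comments added here document the cgMLP branch of the E-Branchformer: GeLU(FF(LN(x))) feeds a convolutional spatial gating unit (CSGU) that half-splits its input into [A, B] and computes A * DwConv(LN(B)) followed by dropout, and a final FF merges the result back. A minimal PyTorch sketch of that dataflow follows; kernel size and dropout rate are illustrative, and the repo builds this as a RETURNN network dict instead (see the linked ESPnet cgmlp.py for a reference implementation).

import torch
import torch.nn as nn

class ConvSpatialGatingUnit(nn.Module):
    """Sketch of the CSGU: half split into [A, B] -> A * DwConv(LN(B)) -> dropout."""

    def __init__(self, cgmlp_ff_dim: int, kernel_size: int = 31, dropout: float = 0.1):
        super().__init__()
        split_size = cgmlp_ff_dim // 2
        self.ln = nn.LayerNorm(split_size)
        # depthwise convolution over time: groups == channels
        self.dw_conv = nn.Conv1d(
            split_size, split_size, kernel_size,
            padding=kernel_size // 2, groups=split_size,
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, cgmlp_ff_dim)
        a, b = x.chunk(2, dim=-1)                             # half split into [A, B]
        b = self.ln(b)                                        # LN(B)
        b = self.dw_conv(b.transpose(1, 2)).transpose(1, 2)   # DwConv over time
        return self.drop(a * b)                               # A * DwConv(LN(B)) -> dropout

class ConvGatingMLP(nn.Module):
    """Sketch of the full branch: GeLU(FF(LN(x))) -> CSGU -> FF."""

    def __init__(self, dim: int, cgmlp_ff_dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.ff1 = nn.Linear(dim, cgmlp_ff_dim)
        self.act = nn.GELU()
        self.csgu = ConvSpatialGatingUnit(cgmlp_ff_dim)
        self.ff2 = nn.Linear(cgmlp_ff_dim // 2, dim)  # the CSGU halves the width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff2(self.csgu(self.act(self.ff1(self.ln(x)))))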
