@@ -903,12 +903,7 @@ def _create_mhsa_module(self, prefix_name, source, layer_index):
903
903
904
904
drop = self .network .add_dropout_layer ("{}_dropout" .format (prefix_name ), mhsa_linear , dropout = self .dropout )
905
905
906
- res_inputs = [drop , source ]
907
-
908
- mhsa_res = self .network .add_combine_layer (
909
- "{}_res" .format (prefix_name ), kind = "add" , source = res_inputs , n_out = self .enc_value_dim
910
- )
911
- return mhsa_res
906
+ return drop
912
907
913
908
def _create_convolution_module (self , prefix_name , source , layer_index , half_step = False ):
914
909
"""
@@ -1071,6 +1066,10 @@ def _create_conformer_block(self, i, source):
1071
1066
if self .convolution_first :
1072
1067
conv_module_ = self ._create_convolution_module (prefix_name , ff_module1 , i )
1073
1068
mhsa_module = self ._create_mhsa_module (prefix_name , conv_module_ , i )
1069
+ mhsa_module = self .network .add_combine_layer (
1070
+ "{}_res" .format (prefix_name ), kind = "add" , source = [mhsa_module , conv_module_ ], n_out = self .enc_value_dim
1071
+ )
1072
+
1074
1073
ff_module2_input = mhsa_module
1075
1074
else :
1076
1075
if self .no_mhsa_module :
@@ -1083,6 +1082,9 @@ def _create_conformer_block(self, i, source):
1083
1082
)
1084
1083
mhsa_input = conv_module1
1085
1084
mhsa = self ._create_mhsa_module (prefix_name , mhsa_input , i )
1085
+ mhsa = self .network .add_combine_layer (
1086
+ "{}_res" .format (prefix_name ), kind = "add" , source = [mhsa , mhsa_input ], n_out = self .enc_value_dim
1087
+ )
1086
1088
1087
1089
conv_module = self ._create_convolution_module (prefix_name , mhsa , i , half_step = self .sandwich_conv )
1088
1090
ff_module2_input = conv_module
0 commit comments