diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py index fa14797e9..782422564 100644 --- a/wenet/paraformer/paraformer.py +++ b/wenet/paraformer/paraformer.py @@ -97,7 +97,8 @@ def forward(self, tp_alphas = tp_alphas.squeeze(-1) tp_token_num = tp_alphas.sum(-1) - return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, tp_token_num, mask + return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, \ + tp_token_num, mask class Paraformer(ASRModel):