@@ -1,6 +1,5 @@
 import torch
 from torch import nn
-import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 import math
 from typing import NamedTuple
@@ -131,7 +130,7 @@ def forward(self, input, return_pi=False):
         :return:
         """
 
-        if self.checkpoint_encoder:
+        if self.checkpoint_encoder and self.training:  # Only checkpoint if we need gradients
             embeddings, _ = checkpoint(self.embedder, self._init_embed(input))
         else:
             embeddings, _ = self.embedder(self._init_embed(input))
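
Note on the checkpointing guard: torch.utils.checkpoint.checkpoint trades compute for memory by discarding intermediate activations and recomputing them during the backward pass, so it only pays off when gradients will actually be computed; in eval mode it just adds recompute overhead, and PyTorch warns when none of the checkpointed inputs require grad. A minimal sketch of the guarded pattern (the toy Encoder below is a stand-in for illustration, not this repo's actual embedder):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class Encoder(nn.Module):
        def __init__(self, dim=128):
            super().__init__()
            self.init_embed = nn.Linear(2, dim)  # stands in for self._init_embed
            self.embedder = nn.Sequential(
                nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
            self.checkpoint_encoder = True

        def forward(self, coords):
            # In training, h requires grad through init_embed's parameters,
            # which checkpointing needs to rebuild the graph on backward.
            h = self.init_embed(coords)
            if self.checkpoint_encoder and self.training:
                return checkpoint(self.embedder, h)
            return self.embedder(h)

    enc = Encoder()
    enc.train()
    loss = enc(torch.rand(8, 2)).sum()
    loss.backward()                  # activations inside embedder are recomputed here
    enc.eval()
    with torch.no_grad():
        out = enc(torch.rand(8, 2))  # plain forward, no pointless recompute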
@@ -360,7 +359,7 @@ def _get_log_p(self, fixed, state, normalize=True):
         log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask)
 
         if normalize:
-            log_p = F.log_softmax(log_p / self.temp, dim=-1)
+            log_p = torch.log_softmax(log_p / self.temp, dim=-1)
 
         assert not torch.isnan(log_p).any()
 
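
The F.log_softmax to torch.log_softmax change is purely a namespace move; both dispatch to the same op, and switching lets the functional import above be dropped. Dividing by self.temp before the log-softmax is standard temperature scaling: temp > 1 flattens the output distribution, temp < 1 sharpens it. A quick equivalence check (tensor shapes here are arbitrary):

    import torch

    logits = torch.randn(4, 10)
    temp = 2.0  # >1 flattens the distribution, <1 sharpens it

    log_p = torch.log_softmax(logits / temp, dim=-1)

    assert torch.allclose(log_p, torch.nn.functional.log_softmax(logits / temp, dim=-1))
    assert torch.allclose(log_p.exp().sum(dim=-1), torch.ones(4))  # rows are valid distributions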
@@ -465,7 +464,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
             compatibility[mask[None, :, :, None, :].expand_as(compatibility)] = -math.inf
 
         # Batch matrix multiplication to compute heads (n_heads, batch_size, num_steps, val_size)
-        heads = torch.matmul(F.softmax(compatibility, dim=-1), glimpse_V)
+        heads = torch.matmul(torch.softmax(compatibility, dim=-1), glimpse_V)
 
         # Project to get glimpse/updated context node embedding (batch_size, num_steps, embedding_dim)
         glimpse = self.project_out(
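
Same story for F.softmax to torch.softmax in the glimpse computation. The surrounding logic is masked multi-head attention: infeasible nodes get -inf compatibility, so after the softmax they receive exactly zero attention weight. A self-contained sketch with hypothetical dimensions (n_heads, batch, steps, n_nodes, val_size are made up for illustration):

    import math
    import torch

    n_heads, batch, steps, n_nodes, val_size = 2, 3, 4, 5, 8
    compatibility = torch.randn(n_heads, batch, steps, 1, n_nodes)
    glimpse_V = torch.randn(n_heads, batch, steps, n_nodes, val_size)

    mask = torch.zeros(batch, steps, n_nodes, dtype=torch.bool)
    mask[:, :, 0] = True  # pretend node 0 is already visited / infeasible

    # exp(-inf) == 0, so masked nodes get exactly zero attention weight
    # (at least one node per row must stay unmasked, or softmax yields NaN)
    compatibility[mask[None, :, :, None, :].expand_as(compatibility)] = -math.inf
    heads = torch.matmul(torch.softmax(compatibility, dim=-1), glimpse_V)

    assert heads.shape == (n_heads, batch, steps, 1, val_size)
    assert (torch.softmax(compatibility, dim=-1)[..., 0] == 0).all()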
@@ -480,7 +479,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
 
         # From the logits compute the probabilities by clipping, masking and softmax
         if self.tanh_clipping > 0:
-            logits = F.tanh(logits) * self.tanh_clipping
+            logits = torch.tanh(logits) * self.tanh_clipping
         if self.mask_logits:
             logits[mask] = -math.inf
 
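
This last hunk is the likely motivation for the whole change: torch.nn.functional.tanh is deprecated in favor of torch.tanh, so the commit migrates everything off the F namespace. The clipping itself bounds the logits to [-tanh_clipping, tanh_clipping] before masking and softmax, which keeps the output distribution from becoming too peaked. A small sketch (all values hypothetical):

    import math
    import torch

    tanh_clipping = 10.0             # C: logits are squashed into [-C, C]
    logits = torch.randn(3, 5) * 50  # deliberately huge raw scores
    mask = torch.zeros(3, 5, dtype=torch.bool)
    mask[:, -1] = True               # last node infeasible

    logits = torch.tanh(logits) * tanh_clipping  # torch.tanh, not the deprecated F.tanh
    logits[mask] = -math.inf

    assert logits[~mask].abs().max() <= tanh_clipping
    log_p = torch.log_softmax(logits, dim=-1)    # masked entries end up at probability 0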