@@ -29,15 +29,14 @@ class AttentionModelFixed(NamedTuple):
     logit_key: torch.Tensor
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):
-            return AttentionModelFixed(
-                node_embeddings=self.node_embeddings[key],
-                context_node_projected=self.context_node_projected[key],
-                glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
-                glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
-                logit_key=self.logit_key[key]
-            )
-        return super(AttentionModelFixed, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)
+        return AttentionModelFixed(
+            node_embeddings=self.node_embeddings[key],
+            context_node_projected=self.context_node_projected[key],
+            glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
+            glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
+            logit_key=self.logit_key[key]
+        )
 
 
 class AttentionModel(nn.Module):
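Note on the `__getitem__` change: the silent fallback to `super().__getitem__` is replaced by an assertion, so any key other than a tensor or slice now fails loudly instead of being treated as plain tuple indexing. Below is a minimal, self-contained sketch of the same pattern; the `FixedSketch` name and all shapes are illustrative, not values from this repo:

```python
from typing import NamedTuple
import torch


class FixedSketch(NamedTuple):
    """Minimal stand-in for AttentionModelFixed, with illustrative fields only."""
    node_embeddings: torch.Tensor   # (batch, n_nodes, dim)
    glimpse_key: torch.Tensor       # (n_heads, batch, n_nodes, head_dim)

    def __getitem__(self, key):
        # Same pattern as the patched class: only tensor/slice keys are valid,
        # anything else (e.g. an int) is treated as a bug and fails the assert.
        assert torch.is_tensor(key) or isinstance(key, slice)
        return FixedSketch(
            node_embeddings=self.node_embeddings[key],
            glimpse_key=self.glimpse_key[:, key],  # keep the head dimension intact
        )


fixed = FixedSketch(torch.randn(4, 5, 8), torch.randn(2, 4, 5, 4))
sub = fixed[torch.tensor([0, 2])]        # select instances 0 and 2 of the batch
assert sub.node_embeddings.size(0) == 2
assert sub.glimpse_key.size(1) == 2      # batch is dim 1 for the per-head tensor
# An integer key such as fixed[0] now trips the assertion rather than
# silently returning a single field via tuple indexing.
```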
@@ -172,7 +171,7 @@ def propose_expansions(self, beam, fixed, expand_size=None, normalize=False, max
         flat_feas = flat_score > -1e10  # != -math.inf triggers
 
         # Parent is row idx of ind_topk, can be found by enumerating elements and dividing by number of columns
-        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) / ind_topk.size(-1)
+        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) // ind_topk.size(-1)
 
         # Filter infeasible
         feas_ind_2d = torch.nonzero(flat_feas)
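Note on the `//` change: `flat_parent` is meant to be an integer row index, and depending on the PyTorch version, `/` on integer tensors warns, errors, or performs true division and returns floats. Floor division keeps the result usable as an index. A small sketch mirroring the patched line, with made-up sizes (3 parents, 2 expansions per parent):

```python
import torch

# Flat index -> parent row: floor-divide by the number of columns.
n_cols = 2
flat_action = torch.arange(3 * n_cols)   # tensor([0, 1, 2, 3, 4, 5])

flat_parent = flat_action // n_cols      # tensor([0, 0, 1, 1, 2, 2])
assert flat_parent.dtype == torch.int64  # stays an integer index

# With '/', recent PyTorch performs true division and yields a float tensor
# (and some versions warn or error on integer inputs), so the result could
# no longer be used directly for indexing.
```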