Commit 21e5a00

Merge pull request #25 from wouterkool/v4
Update to Python 3.8, PyTorch 1.7, some bug fixes
2 parents ffd5b86 + 5fa0b17 commit 21e5a00

8 files changed: +82 -84 lines changed

nets/attention_model.py (+9 -10)
@@ -29,15 +29,14 @@ class AttentionModelFixed(NamedTuple):
     logit_key: torch.Tensor
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):
-            return AttentionModelFixed(
-                node_embeddings=self.node_embeddings[key],
-                context_node_projected=self.context_node_projected[key],
-                glimpse_key=self.glimpse_key[:, key], # dim 0 are the heads
-                glimpse_val=self.glimpse_val[:, key], # dim 0 are the heads
-                logit_key=self.logit_key[key]
-            )
-        return super(AttentionModelFixed, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)
+        return AttentionModelFixed(
+            node_embeddings=self.node_embeddings[key],
+            context_node_projected=self.context_node_projected[key],
+            glimpse_key=self.glimpse_key[:, key], # dim 0 are the heads
+            glimpse_val=self.glimpse_val[:, key], # dim 0 are the heads
+            logit_key=self.logit_key[key]
+        )
 
 
 class AttentionModel(nn.Module):
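Note: the same change from an if/fallback to a plain assert is applied to every __getitem__ below (the problem state NamedTuples and BatchBeam), so an integer key now raises instead of silently falling back to NamedTuple's positional field access. A minimal runnable sketch of the intended tensor/slice indexing, using made-up field names and shapes rather than the repository's actual ones:

import torch
from typing import NamedTuple


class Fixed(NamedTuple):  # hypothetical stand-in for AttentionModelFixed
    node_embeddings: torch.Tensor  # (batch, n_nodes, embed_dim)
    glimpse_key: torch.Tensor      # (n_heads, batch, n_nodes, key_dim)

    def __getitem__(self, key):
        # only tensor or slice keys make sense: gather every cached tensor by the same index
        assert torch.is_tensor(key) or isinstance(key, slice)
        return Fixed(
            node_embeddings=self.node_embeddings[key],
            glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
        )


fixed = Fixed(torch.randn(4, 10, 16), torch.randn(8, 4, 10, 2))
sub = fixed[torch.tensor([0, 2])]   # keep two entries of the batch
print(sub.node_embeddings.shape)    # torch.Size([2, 10, 16])
print(sub.glimpse_key.shape)        # torch.Size([8, 2, 10, 2])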
@@ -172,7 +171,7 @@ def propose_expansions(self, beam, fixed, expand_size=None, normalize=False, max
         flat_feas = flat_score > -1e10 # != -math.inf triggers
 
         # Parent is row idx of ind_topk, can be found by enumerating elements and dividing by number of columns
-        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) / ind_topk.size(-1)
+        flat_parent = torch.arange(flat_action.size(-1), out=flat_action.new()) // ind_topk.size(-1)
 
         # Filter infeasible
         feas_ind_2d = torch.nonzero(flat_feas)
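The division fix above is tied to the PyTorch upgrade: recent PyTorch versions deprecate the / operator on integer tensors in favour of true division, so the enumerated positions would stop being usable as integer row indices, while floor division with // keeps flat_parent integral. A small sketch with made-up sizes:

import torch

n_rows, n_cols = 3, 4                  # e.g. rows of ind_topk x number of columns
flat = torch.arange(n_rows * n_cols)   # 0 .. 11, integer dtype
parent = flat // n_cols                # parent row index of each flattened element
print(parent)                          # tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
# With /, newer PyTorch performs true division and returns floats,
# e.g. flat / n_cols -> tensor([0.0000, 0.2500, ...]), which cannot be used directly as an index.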

nets/graph_encoder.py (+11 -4)
@@ -19,14 +19,13 @@ def __init__(
             self,
             n_heads,
             input_dim,
-            embed_dim=None,
+            embed_dim,
             val_dim=None,
             key_dim=None
     ):
         super(MultiHeadAttention, self).__init__()
 
         if val_dim is None:
-            assert embed_dim is not None, "Provide either embed_dim or val_dim"
             val_dim = embed_dim // n_heads
         if key_dim is None:
             key_dim = val_dim
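With embed_dim now a required argument, val_dim and key_dim can always be derived from it when omitted, which is why the old "Provide either embed_dim or val_dim" assert is dropped. A tiny sketch of that defaulting, as a hypothetical stand-alone helper rather than the repository's class:

def resolve_dims(n_heads, embed_dim, val_dim=None, key_dim=None):
    # mirrors the defaulting in MultiHeadAttention.__init__ after this change
    if val_dim is None:
        val_dim = embed_dim // n_heads
    if key_dim is None:
        key_dim = val_dim
    return val_dim, key_dim


print(resolve_dims(n_heads=8, embed_dim=128))  # (16, 16)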
@@ -43,8 +42,7 @@ def __init__(
         self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
         self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
 
-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))
 
         self.init_parameters()
 
@@ -109,6 +107,15 @@ def forward(self, q, h=None, mask=None):
             self.W_out.view(-1, self.embed_dim)
         ).view(batch_size, n_query, self.embed_dim)
 
+        # Alternative:
+        # headst = heads.transpose(0, 1)  # swap the dimensions for batch and heads to align it for the matmul
+        # # proj_h = torch.einsum('bhni,hij->bhnj', headst, self.W_out)
+        # projected_heads = torch.matmul(headst, self.W_out)
+        # out = torch.sum(projected_heads, dim=1)  # sum across heads
+
+        # Or:
+        # out = torch.einsum('hbni,hij->bnj', heads, self.W_out)
+
         return out
 
 
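The new comments describe two equivalent ways of projecting the heads and summing them into the output; the sketch below (made-up sizes, not the repository's defaults) checks that they agree, and also shows why W_out's middle dimension has to be val_dim rather than key_dim, consistent with the parameter-shape fix above:

import torch

n_heads, batch, n_query, val_dim, embed_dim = 8, 2, 5, 16, 128
heads = torch.randn(n_heads, batch, n_query, val_dim)
W_out = torch.randn(n_heads, val_dim, embed_dim)   # middle dimension must match val_dim

# per-head matmul, then sum across the head dimension
headst = heads.transpose(0, 1)                     # (batch, n_heads, n_query, val_dim)
out_a = torch.matmul(headst, W_out).sum(dim=1)     # (batch, n_query, embed_dim)

# a single einsum doing the same contraction
out_b = torch.einsum('hbni,hij->bnj', heads, W_out)

print(torch.allclose(out_a, out_b, atol=1e-4))     # True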

problems/op/state_op.py (+9 -10)
@@ -36,16 +36,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-                cur_total_prize=self.cur_total_prize[key],
-            )
-        return super(StateOP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+            cur_total_prize=self.cur_total_prize[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):

problems/pctsp/state_pctsp.py (+10 -11)
@@ -36,17 +36,16 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_total_prize=self.cur_total_prize[key],
-                cur_total_penalty=self.cur_total_penalty[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StatePCTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_total_prize=self.cur_total_prize[key],
+            cur_total_penalty=self.cur_total_penalty[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):

problems/tsp/state_tsp.py (+9 -10)
@@ -28,16 +28,15 @@ def visited(self):
         return mask_long2bool(self.visited_, n=self.loc.size(-2))
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                first_a=self.first_a[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
-            )
-        return super(StateTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            first_a=self.first_a[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
+        )
 
     @staticmethod
     def initialize(loc, visited_dtype=torch.uint8):

problems/vrp/state_cvrp.py (+9 -10)
@@ -34,16 +34,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateCVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):

problems/vrp/state_sdvrp.py (+9 -10)
@@ -22,16 +22,15 @@ class StateSDVRP(NamedTuple):
     VEHICLE_CAPACITY = 1.0 # Hardcoded
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                demands_with_depot=self.demands_with_depot[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateSDVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            demands_with_depot=self.demands_with_depot[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     @staticmethod
     def initialize(input):

utils/beam_search.py (+16 -19)
@@ -70,15 +70,14 @@ def ids(self):
         return self.state.ids.view(-1) # Need to flat as state has steps dimension
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice): # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                # ids=self.ids[key],
-                score=self.score[key] if self.score is not None else None,
-                state=self.state[key],
-                parent=self.parent[key] if self.parent is not None else None,
-                action=self.action[key] if self.action is not None else None
-            )
-        return super(BatchBeam, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice) # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            # ids=self.ids[key],
+            score=self.score[key] if self.score is not None else None,
+            state=self.state[key],
+            parent=self.parent[key] if self.parent is not None else None,
+            action=self.action[key] if self.action is not None else None
+        )
 
     # Do not use __len__ since this is used by namedtuple internally and should be number of fields
     # def __len__(self):
@@ -207,15 +206,13 @@ def __getitem__(self, key):
         assert not isinstance(key, slice), "CachedLookup does not support slicing, " \
                                            "you can slice the result of an index operation instead"
 
-        if torch.is_tensor(key): # If tensor, idx all tensors by this tensor:
-
-            if self.key is None:
-                self.key = key
-                self.current = self.orig[key]
-            elif len(key) != len(self.key) or (key != self.key).any():
-                self.key = key
-                self.current = self.orig[key]
+        assert torch.is_tensor(key) # If tensor, idx all tensors by this tensor:
 
-            return self.current
+        if self.key is None:
+            self.key = key
+            self.current = self.orig[key]
+        elif len(key) != len(self.key) or (key != self.key).any():
+            self.key = key
+            self.current = self.orig[key]
 
-        return super(CachedLookup, self).__getitem__(key)
+        return self.current
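After this change CachedLookup always expects a tensor key: it returns the cached gather and recomputes it only when the key tensor differs from the cached one, and the old fallback to super().__getitem__ is gone. A simplified re-implementation of just that caching behaviour (a sketch, not the repository class):

import torch


class ToyCachedLookup:
    def __init__(self, orig):
        self.orig = orig      # anything indexable by a tensor, e.g. a tensor or a state tuple
        self.key = None
        self.current = None

    def __getitem__(self, key):
        assert torch.is_tensor(key)
        if self.key is None:
            self.key = key
            self.current = self.orig[key]
        elif len(key) != len(self.key) or (key != self.key).any():
            self.key = key
            self.current = self.orig[key]
        return self.current


lookup = ToyCachedLookup(torch.arange(10) * 100)
print(lookup[torch.tensor([1, 3])])   # tensor([100, 300])  - computed
print(lookup[torch.tensor([1, 3])])   # tensor([100, 300])  - served from the cache
print(lookup[torch.tensor([2])])      # tensor([200])       - key changed, recomputed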
