
Commit a6d0623
Merge branch 'v3'
2 parents b7988b2 + f42d90a

9 files changed: +21 -26 lines

README.md (+1 -1)

@@ -21,7 +21,7 @@ For more details, please see our paper [Attention, Learn to Solve Routing Proble
 * Python>=3.6
 * NumPy
 * SciPy
-* [PyTorch](http://pytorch.org/)=0.4
+* [PyTorch](http://pytorch.org/)>=1.1
 * tqdm
 * [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger)
 * Matplotlib (optional, only for plotting)
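
Since the requirement moves from PyTorch 0.4 to >=1.1, it can be worth checking the installed version before running anything. A minimal sketch (the version-string parsing is my own, not part of the repo):

import torch

# torch.__version__ is a string such as "1.1.0" or "1.8.1+cu111"; compare the major/minor part
major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
assert (major, minor) >= (1, 1), "PyTorch >= 1.1 is required, found {}".format(torch.__version__)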

nets/attention_model.py (+4 -5)

@@ -1,6 +1,5 @@
 import torch
 from torch import nn
-import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 import math
 from typing import NamedTuple
@@ -131,7 +130,7 @@ def forward(self, input, return_pi=False):
         :return:
         """
 
-        if self.checkpoint_encoder:
+        if self.checkpoint_encoder and self.training:  # Only checkpoint if we need gradients
             embeddings, _ = checkpoint(self.embedder, self._init_embed(input))
         else:
             embeddings, _ = self.embedder(self._init_embed(input))
@@ -360,7 +359,7 @@ def _get_log_p(self, fixed, state, normalize=True):
         log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask)
 
         if normalize:
-            log_p = F.log_softmax(log_p / self.temp, dim=-1)
+            log_p = torch.log_softmax(log_p / self.temp, dim=-1)
 
         assert not torch.isnan(log_p).any()
 
@@ -465,7 +464,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
             compatibility[mask[None, :, :, None, :].expand_as(compatibility)] = -math.inf
 
         # Batch matrix multiplication to compute heads (n_heads, batch_size, num_steps, val_size)
-        heads = torch.matmul(F.softmax(compatibility, dim=-1), glimpse_V)
+        heads = torch.matmul(torch.softmax(compatibility, dim=-1), glimpse_V)
 
         # Project to get glimpse/updated context node embedding (batch_size, num_steps, embedding_dim)
         glimpse = self.project_out(
@@ -480,7 +479,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
 
         # From the logits compute the probabilities by clipping, masking and softmax
         if self.tanh_clipping > 0:
-            logits = F.tanh(logits) * self.tanh_clipping
+            logits = torch.tanh(logits) * self.tanh_clipping
         if self.mask_logits:
             logits[mask] = -math.inf
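
The guard self.checkpoint_encoder and self.training reflects that gradient checkpointing only helps when gradients are actually needed; in eval mode it would just recompute the encoder for nothing, and recent PyTorch warns that none of the inputs require grad. A minimal sketch of the pattern, assuming a toy Encoder in place of the real graph attention encoder:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Encoder(nn.Module):
    # Hypothetical stand-in for the multi-head attention graph encoder
    def __init__(self, dim=128):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.layers(x)

class Model(nn.Module):
    def __init__(self, node_dim=2, dim=128, checkpoint_encoder=True):
        super().__init__()
        self.init_embed = nn.Linear(node_dim, dim)
        self.embedder = Encoder(dim)
        self.checkpoint_encoder = checkpoint_encoder

    def forward(self, x):
        h = self.init_embed(x)  # depends on trainable weights, so it carries grad in training mode
        if self.checkpoint_encoder and self.training:
            # Trade compute for memory: activations are recomputed during the backward pass
            return checkpoint(self.embedder, h)
        # In eval mode (or when disabled) checkpointing would only add recomputation overhead
        return self.embedder(h)

model = Model()
out = model(torch.rand(2, 10, 2))  # training mode: encoder runs under checkpointing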

nets/graph_encoder.py (+1 -2)

@@ -1,5 +1,4 @@
 import torch
-import torch.nn.functional as F
 import numpy as np
 from torch import nn
 import math
@@ -95,7 +94,7 @@ def forward(self, q, h=None, mask=None):
             mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
             compatibility[mask] = -np.inf
 
-        attn = F.softmax(compatibility, dim=-1)
+        attn = torch.softmax(compatibility, dim=-1)
 
         # If there are nodes with no neighbours then softmax returns nan so we fix them to 0
         if mask is not None:
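
The context line about nodes with no neighbours points at a real pitfall of masked attention: when every entry of a row is set to -inf, softmax over that row returns nan, which is why the encoder zeroes those rows afterwards. A small self-contained illustration (shapes and the masked_fill-based cleanup are illustrative, not the encoder's exact code):

import torch

compatibility = torch.randn(2, 4)
mask = torch.tensor([[False, True, False, True],
                     [True, True, True, True]])  # second row: no valid neighbours at all

compatibility[mask] = -float('inf')
attn = torch.softmax(compatibility, dim=-1)
print(attn[1])  # tensor([nan, nan, nan, nan]): the whole row was -inf

# Clean up fully-masked rows so downstream matmuls see zeros instead of nan
attn = attn.masked_fill(torch.isnan(attn), 0.0)
print(attn[1])  # tensor([0., 0., 0., 0.])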

nets/pointer_network.py (+1 -2)

@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-import torch.nn.functional as F
 import math
 import numpy as np
 
@@ -105,7 +104,7 @@ def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context):
         logits, h_out = self.calc_logits(x, h_in, logit_mask, context, self.mask_glimpses, self.mask_logits)
 
         # Calculate log_softmax for better numerical stability
-        log_p = F.log_softmax(logits, dim=1)
+        log_p = torch.log_softmax(logits, dim=1)
         probs = log_p.exp()
 
         if not self.mask_logits:
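
The comment about numerical stability is worth spelling out: computing log(softmax(x)) in two steps underflows to log(0) = -inf for strongly negative logits, whereas log_softmax works in log-space and keeps finite values. A quick illustration, with values chosen only to trigger the underflow:

import torch

logits = torch.tensor([[0.0, -1000.0]])

unstable = torch.log(torch.softmax(logits, dim=1))
stable = torch.log_softmax(logits, dim=1)

print(unstable)  # second entry is -inf: softmax underflowed to exactly 0
print(stable)    # second entry stays -1000, computed directly in log-space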

problems/op/state_op.py (+6 -8)

@@ -142,17 +142,15 @@ def get_mask(self):
         :return:
         """
 
+        exceeds_length = (
+            self.lengths[:, :, None] + (self.coords[self.ids, :, :] - self.cur_coord[:, :, None, :]).norm(p=2, dim=-1)
+            > self.max_length[self.ids, :]
+        )
         # Note: this always allows going to the depot, but that should always be suboptimal so be ok
         # Cannot visit if already visited or if length that would be upon arrival is too large to return to depot
         # If the depot has already been visited then we cannot visit anymore
-        visited_ = self.visited
-        mask = (
-            visited_ | visited_[:, :, 0:1] |
-            (
-                self.lengths[:, :, None] + (self.coords[self.ids, :, :] - self.cur_coord[:, :, None, :]).norm(p=2, dim=-1)
-                > self.max_length[self.ids, :]
-            )
-        )
+        visited_ = self.visited.to(exceeds_length.dtype)
+        mask = visited_ | visited_[:, :, 0:1] | exceeds_length
         # Depot can always be visited
         # (so we do not hardcode knowledge that this is strictly suboptimal if other options are available)
         mask[:, :, 0] = 0
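
The refactor of get_mask is not only cosmetic. Around PyTorch 1.2 comparison operators started returning bool tensors while the stored visited mask can still be uint8, and mixing the two with the | operator raised a dtype error on those versions. Casting one operand to the other's dtype, as done with self.visited.to(exceeds_length.dtype) above, keeps the expression working on both old and new PyTorch. A minimal sketch of the issue with made-up tensors (names and shapes are illustrative only):

import torch

visited = torch.zeros(1, 1, 5, dtype=torch.uint8)  # old-style uint8 mask
lengths = torch.rand(1, 1, 5)
exceeds_length = lengths > 0.5  # bool on PyTorch >= 1.2, uint8 on older versions

# visited | exceeds_length  # raised "expected ... Byte but got ... Bool" on the versions this commit targets
mask = visited.to(exceeds_length.dtype) | exceeds_length  # same dtype on both sides works everywhere
print(mask.dtype)  # torch.bool on current PyTorch, torch.uint8 on pre-1.2 versions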

problems/pctsp/state_pctsp.py (+1 -1)

@@ -162,7 +162,7 @@ def get_mask(self):
         # Cannot visit depot if not yet collected 1 total prize and there are unvisited nodes
         mask[:, :, 0] = (self.cur_total_prize < 1.) & (visited_[:, :, 1:].int().sum(-1) < visited_[:, :, 1:].size(-1))
 
-        return mask
+        return mask > 0  # Hacky way to return bool or uint8 depending on pytorch version
 
     def construct_solutions(self, actions):
         return actions

problems/tsp/state_tsp.py (+1 -1)

@@ -106,7 +106,7 @@ def get_current_node(self):
         return self.prev_a
 
     def get_mask(self):
-        return self.visited
+        return self.visited > 0  # Hacky way to return bool or uint8 depending on pytorch version
 
     def get_nn(self, k=None):
         # Insert step dimension

problems/vrp/state_cvrp.py (+3 -5)

@@ -143,12 +143,10 @@ def get_mask(self):
         else:
             visited_loc = mask_long2bool(self.visited_, n=self.demand.size(-1))
 
+        # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting
+        exceeds_cap = (self.demand[self.ids, :] + self.used_capacity[:, :, None] > self.VEHICLE_CAPACITY)
         # Nodes that cannot be visited are already visited or too much demand to be served now
-        mask_loc = (
-            visited_loc |
-            # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting
-            (self.demand[self.ids, :] + self.used_capacity[:, :, None] > self.VEHICLE_CAPACITY)
-        )
+        mask_loc = visited_loc.to(exceeds_cap.dtype) | exceeds_cap
 
         # Cannot visit the depot if just visited and still unserved nodes
         mask_depot = (self.prev_a == 0) & ((mask_loc == 0).int().sum(-1) > 0)

train.py (+3 -1)

@@ -68,7 +68,6 @@ def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, pr
     print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
     step = epoch * (opts.epoch_size // opts.batch_size)
     start_time = time.time()
-    lr_scheduler.step(epoch)
 
     if not opts.no_tensorboard:
         tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)
@@ -121,6 +120,9 @@ def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, pr
 
     baseline.epoch_callback(model, epoch)
 
+    # lr_scheduler should be called at end of epoch
+    lr_scheduler.step()
+
 
 def train_batch(
         model,
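
Moving the scheduler call to the end of the epoch matches the contract introduced in PyTorch 1.1: optimizer.step() should run before lr_scheduler.step(), otherwise the first value of the schedule is skipped and PyTorch emits a warning, and passing the epoch explicitly is no longer needed. A condensed sketch of the intended ordering (model, data and the LambdaLR decay factor are placeholders):

import torch
from torch import nn, optim

model = nn.Linear(10, 1)  # placeholder model
optimizer = optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** epoch)

for epoch in range(3):
    for _ in range(5):  # placeholder inner training loop
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()   # update parameters first ...
    lr_scheduler.step()    # ... then advance the schedule once per epoch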
