
Commit f42d90a

Updated to latest PyTorch
wouterkool committed Dec 10, 2019
1 parent f563e33 commit f42d90a
Showing 9 changed files with 22 additions and 27 deletions.
README.md: 4 changes (2 additions, 2 deletions)
@@ -18,10 +18,10 @@ For more details, please see our paper [Attention, Learn to Solve Routing Proble

## Dependencies

- * Python>=3.5
+ * Python>=3.6
* NumPy
* SciPy
- * [PyTorch](http://pytorch.org/)=0.4
+ * [PyTorch](http://pytorch.org/)>=1.1
* tqdm
* [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger)
* Matplotlib (optional, only for plotting)
nets/attention_model.py: 9 changes (4 additions, 5 deletions)
@@ -1,6 +1,5 @@
import torch
from torch import nn
- import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import math
from typing import NamedTuple
@@ -131,7 +130,7 @@ def forward(self, input, return_pi=False):
:return:
"""

- if self.checkpoint_encoder:
+ if self.checkpoint_encoder and self.training:  # Only checkpoint if we need gradients
embeddings, _ = checkpoint(self.embedder, self._init_embed(input))
else:
embeddings, _ = self.embedder(self._init_embed(input))
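A note on the guard: self.training is the standard nn.Module flag toggled by .train() and .eval(), and torch.utils.checkpoint only pays off when gradients are needed, since it saves memory by discarding intermediate activations and recomputing them during backward. A minimal sketch of the pattern (the encoder, shapes and names below are illustrative, not taken from the repository):

```python
import torch
from torch.utils.checkpoint import checkpoint

encoder = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

def encode(x, training):
    # Checkpointing discards intermediate activations and recomputes them in the
    # backward pass, so it only helps when gradients are actually needed.
    if training:
        return checkpoint(encoder, x)
    return encoder(x)  # plain forward pass for evaluation / greedy decoding

x = torch.randn(4, 16, requires_grad=True)
out_train = encode(x, training=True)
with torch.no_grad():
    out_eval = encode(x, training=False)
```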
@@ -360,7 +359,7 @@ def _get_log_p(self, fixed, state, normalize=True):
log_p, glimpse = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask)

if normalize:
- log_p = F.log_softmax(log_p / self.temp, dim=-1)
+ log_p = torch.log_softmax(log_p / self.temp, dim=-1)

assert not torch.isnan(log_p).any()
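For reference, torch.log_softmax and torch.nn.functional.log_softmax compute the same thing, so this change only drops the dependence on the F alias; dividing by a temperature before normalizing flattens (temp > 1) or sharpens (temp < 1) the distribution. A small sanity check with made-up logits:

```python
import torch

logits = torch.randn(2, 5)
temp = 0.5  # illustrative temperature, not a value taken from this diff

a = torch.log_softmax(logits / temp, dim=-1)
b = torch.nn.functional.log_softmax(logits / temp, dim=-1)
assert torch.allclose(a, b)      # same operation, only the namespace differs
assert not torch.isnan(a).any()  # mirrors the assert in the model code
```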

@@ -465,7 +464,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
compatibility[mask[None, :, :, None, :].expand_as(compatibility)] = -math.inf

# Batch matrix multiplication to compute heads (n_heads, batch_size, num_steps, val_size)
- heads = torch.matmul(F.softmax(compatibility, dim=-1), glimpse_V)
+ heads = torch.matmul(torch.softmax(compatibility, dim=-1), glimpse_V)

# Project to get glimpse/updated context node embedding (batch_size, num_steps, embedding_dim)
glimpse = self.project_out(
@@ -480,7 +479,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):

# From the logits compute the probabilities by clipping, masking and softmax
if self.tanh_clipping > 0:
- logits = F.tanh(logits) * self.tanh_clipping
+ logits = torch.tanh(logits) * self.tanh_clipping
if self.mask_logits:
logits[mask] = -math.inf
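F.tanh has been deprecated in favour of torch.tanh for a while, which is presumably why these call sites changed; the clipping itself bounds the logits to (-tanh_clipping, tanh_clipping) before masking and normalization so the policy does not become overly peaked. A toy illustration (the clipping value and shapes are made up):

```python
import math
import torch

tanh_clipping = 10.0                          # example value, not read from this diff
logits = torch.randn(1, 1, 20) * 100          # deliberately large, unbounded logits
logits = torch.tanh(logits) * tanh_clipping   # now bounded to (-10, 10)

mask = torch.zeros_like(logits, dtype=torch.bool)
mask[..., 0] = True                           # pretend the first node is infeasible
logits[mask] = -math.inf                      # masked nodes get probability zero
log_p = torch.log_softmax(logits, dim=-1)
```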

nets/graph_encoder.py: 3 changes (1 addition, 2 deletions)
@@ -1,5 +1,4 @@
import torch
- import torch.nn.functional as F
import numpy as np
from torch import nn
import math
@@ -95,7 +94,7 @@ def forward(self, q, h=None, mask=None):
mask = mask.view(1, batch_size, n_query, graph_size).expand_as(compatibility)
compatibility[mask] = -np.inf

- attn = F.softmax(compatibility, dim=-1)
+ attn = torch.softmax(compatibility, dim=-1)

# If there are nodes with no neighbours then softmax returns nan so we fix them to 0
if mask is not None:
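The comment about nodes with no neighbours describes a real edge case: if every entry of a row in the compatibility matrix is -inf, softmax divides zero by zero and produces NaN for that row. A minimal reproduction, with one possible way to zero out the resulting NaNs (not necessarily the repository's exact fix):

```python
import math
import torch

compatibility = torch.tensor([
    [0.5, 1.0, -math.inf],                # a node with at least one neighbour
    [-math.inf, -math.inf, -math.inf],    # a node with no neighbours at all
])
attn = torch.softmax(compatibility, dim=-1)
print(attn[1])                            # tensor([nan, nan, nan])
attn = torch.where(torch.isnan(attn), torch.zeros_like(attn), attn)  # zero out the NaN rows
```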
nets/pointer_network.py: 3 changes (1 addition, 2 deletions)
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
- import torch.nn.functional as F
import math
import numpy as np

@@ -105,7 +104,7 @@ def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context):
logits, h_out = self.calc_logits(x, h_in, logit_mask, context, self.mask_glimpses, self.mask_logits)

# Calculate log_softmax for better numerical stability
- log_p = F.log_softmax(logits, dim=1)
+ log_p = torch.log_softmax(logits, dim=1)
probs = log_p.exp()

if not self.mask_logits:
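The numerical-stability comment is worth spelling out: computing probabilities as exp(logits) / sum(exp(logits)) overflows for large logits, whereas log_softmax applies the log-sum-exp trick internally, so exponentiating its output afterwards stays finite. A quick illustration with made-up logits:

```python
import torch

logits = torch.tensor([[1000.0, 0.0]])

naive = torch.exp(logits) / torch.exp(logits).sum(dim=1, keepdim=True)  # exp(1000) overflows
stable = torch.log_softmax(logits, dim=1).exp()                         # tensor([[1., 0.]])
print(naive, stable)
```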
problems/op/state_op.py: 14 changes (6 additions, 8 deletions)
@@ -142,17 +142,15 @@ def get_mask(self):
:return:
"""

+ exceeds_length = (
+     self.lengths[:, :, None] + (self.coords[self.ids, :, :] - self.cur_coord[:, :, None, :]).norm(p=2, dim=-1)
+     > self.max_length[self.ids, :]
+ )
# Note: this always allows going to the depot, but that should always be suboptimal so be ok
# Cannot visit if already visited or if length that would be upon arrival is too large to return to depot
# If the depot has already been visited then we cannot visit anymore
- visited_ = self.visited
- mask = (
-     visited_ | visited_[:, :, 0:1] |
-     (
-         self.lengths[:, :, None] + (self.coords[self.ids, :, :] - self.cur_coord[:, :, None, :]).norm(p=2, dim=-1)
-         > self.max_length[self.ids, :]
-     )
- )
+ visited_ = self.visited.to(exceeds_length.dtype)
+ mask = visited_ | visited_[:, :, 0:1] | exceeds_length
# Depot can always be visited
# (so we do not hardcode knowledge that this is strictly suboptimal if other options are available)
mask[:, :, 0] = 0
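The refactor computes the length-feasibility check once as exceeds_length and casts the visited mask to the same dtype before combining them, presumably because comparisons return torch.bool on newer PyTorch while the stored mask may be uint8. A simplified sketch of the masking logic with toy tensors (one batch element, a depot plus two nodes, and no step dimension):

```python
import torch

lengths = torch.tensor([[2.0]])                                 # distance travelled so far
cur_coord = torch.tensor([[0.0, 0.0]])                          # current position
coords = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]]])   # depot + two prize nodes
max_length = torch.tensor([[4.0]])
visited = torch.tensor([[0, 1, 0]], dtype=torch.uint8)          # node 1 already visited

exceeds_length = lengths + (coords - cur_coord[:, None, :]).norm(p=2, dim=-1) > max_length
visited_ = visited.to(exceeds_length.dtype)                     # uint8 -> bool on newer PyTorch
mask = visited_ | visited_[:, 0:1] | exceeds_length
mask[:, 0] = 0                                                  # the depot can always be visited
print(mask)  # tensor([[False,  True,  True]]): only the depot is still feasible
```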
problems/pctsp/state_pctsp.py: 2 changes (1 addition, 1 deletion)
@@ -162,7 +162,7 @@ def get_mask(self):
# Cannot visit depot if not yet collected 1 total prize and there are unvisited nodes
mask[:, :, 0] = (self.cur_total_prize < 1.) & (visited_[:, :, 1:].int().sum(-1) < visited_[:, :, 1:].size(-1))

- return mask
+ return mask > 0  # Hacky way to return bool or uint8 depending on pytorch version

def construct_solutions(self, actions):
return actions
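The > 0 comparison is a version-agnostic way to get a mask in the dtype that downstream indexing expects: comparisons return uint8 on PyTorch before 1.2 and torch.bool from 1.2 onward, hence the "hacky" comment. The same change appears in state_tsp.py below. A tiny sketch of the idea:

```python
import torch

visited = torch.tensor([[0, 1, 0]], dtype=torch.uint8)  # masks are stored compactly as uint8
mask = visited > 0   # uint8 on PyTorch < 1.2, torch.bool on >= 1.2
print(mask.dtype)
```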
problems/tsp/state_tsp.py: 2 changes (1 addition, 1 deletion)
@@ -106,7 +106,7 @@ def get_current_node(self):
return self.prev_a

def get_mask(self):
- return self.visited
+ return self.visited > 0  # Hacky way to return bool or uint8 depending on pytorch version

def get_nn(self, k=None):
# Insert step dimension
problems/vrp/state_cvrp.py: 8 changes (3 additions, 5 deletions)
@@ -143,12 +143,10 @@ def get_mask(self):
else:
visited_loc = mask_long2bool(self.visited_, n=self.demand.size(-1))

+ # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting
+ exceeds_cap = (self.demand[self.ids, :] + self.used_capacity[:, :, None] > self.VEHICLE_CAPACITY)
# Nodes that cannot be visited are already visited or too much demand to be served now
- mask_loc = (
-     visited_loc |
-     # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting
-     (self.demand[self.ids, :] + self.used_capacity[:, :, None] > self.VEHICLE_CAPACITY)
- )
+ mask_loc = visited_loc.to(exceeds_cap.dtype) | exceeds_cap

# Cannot visit the depot if just visited and still unserved nodes
mask_depot = (self.prev_a == 0) & ((mask_loc == 0).int().sum(-1) > 0)
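The CVRP mask follows the same pattern as the OP change above: compute the capacity check once, cast visited_loc to its dtype before OR-ing, and mask the depot only when the vehicle is already there and some customer is still feasible. A toy sketch with simplified shapes (no step dimension, values made up):

```python
import torch

VEHICLE_CAPACITY = 1.0
demand = torch.tensor([[0.3, 0.5, 0.9]])                     # customer demands (depot excluded)
used_capacity = torch.tensor([[0.4]])
visited_loc = torch.tensor([[1, 0, 0]], dtype=torch.uint8)   # first customer already served
prev_a = torch.tensor([[2]])                                  # last action was a customer, not the depot

exceeds_cap = demand + used_capacity > VEHICLE_CAPACITY      # [[False, False, True]]
mask_loc = visited_loc.to(exceeds_cap.dtype) | exceeds_cap   # [[True, False, True]]

# The depot is masked only if we are standing at it while some customer is still feasible
mask_depot = (prev_a == 0) & ((mask_loc == 0).int().sum(-1, keepdim=True) > 0)
print(mask_loc, mask_depot)
```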
train.py: 4 changes (3 additions, 1 deletion)
@@ -68,7 +68,6 @@ def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, pr
print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
step = epoch * (opts.epoch_size // opts.batch_size)
start_time = time.time()
- lr_scheduler.step(epoch)

if not opts.no_tensorboard:
tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)
@@ -121,6 +120,9 @@

baseline.epoch_callback(model, epoch)

+ # lr_scheduler should be called at end of epoch
+ lr_scheduler.step()


def train_batch(
model,
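Moving the scheduler call matches the PyTorch 1.1+ convention: lr_scheduler.step() should be called after the optimizer has updated the parameters (here once per epoch), and passing the epoch index is deprecated in later releases; newer versions also warn if the scheduler is stepped before the optimizer. A minimal sketch of the intended ordering (model, data and schedule are placeholders):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.98 ** epoch)

for epoch in range(3):
    for _ in range(10):                 # placeholder inner loop over training batches
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).mean()
        loss.backward()
        optimizer.step()
    lr_scheduler.step()                 # step the scheduler once, after the epoch's optimizer updates
```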
