Remove super call in __getitem__ which causes problems in Python 3.8
wouterkool committed Nov 20, 2020
1 parent ffd5b86 commit 6c5b78e
Showing 7 changed files with 70 additions and 79 deletions.
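
All seven files get the same fix: the fallback return super(...).__getitem__(key) is removed and the tensor-or-slice check becomes an assertion, so batched indexing always goes through _replace (or the NamedTuple constructor). A minimal, self-contained sketch of the new pattern follows; the class and field names are made up for illustration and are not taken from the repository:

from typing import NamedTuple

import torch


class StateSketch(NamedTuple):
    ids: torch.Tensor      # (batch, 1)
    lengths: torch.Tensor  # (batch,)

    def __getitem__(self, key):
        # After the change, only tensor or slice keys are accepted;
        # anything else fails fast instead of falling back to tuple indexing.
        assert torch.is_tensor(key) or isinstance(key, slice)
        return self._replace(ids=self.ids[key], lengths=self.lengths[key])


state = StateSketch(ids=torch.arange(4)[:, None], lengths=torch.zeros(4))
subset = state[torch.tensor([0, 2])]  # every field indexed by the same key
print(subset.lengths.shape)  # torch.Size([2])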
17 changes: 8 additions & 9 deletions nets/attention_model.py
@@ -29,15 +29,14 @@ class AttentionModelFixed(NamedTuple):
     logit_key: torch.Tensor
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):
-            return AttentionModelFixed(
-                node_embeddings=self.node_embeddings[key],
-                context_node_projected=self.context_node_projected[key],
-                glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
-                glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
-                logit_key=self.logit_key[key]
-            )
-        return super(AttentionModelFixed, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)
+        return AttentionModelFixed(
+            node_embeddings=self.node_embeddings[key],
+            context_node_projected=self.context_node_projected[key],
+            glimpse_key=self.glimpse_key[:, key],  # dim 0 are the heads
+            glimpse_val=self.glimpse_val[:, key],  # dim 0 are the heads
+            logit_key=self.logit_key[key]
+        )
 
 
 class AttentionModel(nn.Module):
19 changes: 9 additions & 10 deletions problems/op/state_op.py
@@ -36,16 +36,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-                cur_total_prize=self.cur_total_prize[key],
-            )
-        return super(StateOP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+            cur_total_prize=self.cur_total_prize[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
21 changes: 10 additions & 11 deletions problems/pctsp/state_pctsp.py
@@ -36,17 +36,16 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_total_prize=self.cur_total_prize[key],
-                cur_total_penalty=self.cur_total_penalty[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StatePCTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_total_prize=self.cur_total_prize[key],
+            cur_total_penalty=self.cur_total_penalty[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
19 changes: 9 additions & 10 deletions problems/tsp/state_tsp.py
@@ -28,16 +28,15 @@ def visited(self):
         return mask_long2bool(self.visited_, n=self.loc.size(-2))
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                first_a=self.first_a[key],
-                prev_a=self.prev_a[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
-            )
-        return super(StateTSP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            first_a=self.first_a[key],
+            prev_a=self.prev_a[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key] if self.cur_coord is not None else None,
+        )
 
     @staticmethod
     def initialize(loc, visited_dtype=torch.uint8):
19 changes: 9 additions & 10 deletions problems/vrp/state_cvrp.py
@@ -34,16 +34,15 @@ def dist(self):
         return (self.coords[:, :, None, :] - self.coords[:, None, :, :]).norm(p=2, dim=-1)
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                visited_=self.visited_[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateCVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            visited_=self.visited_[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     # Warning: cannot override len of NamedTuple, len should be number of fields, not batch size
     # def __len__(self):
19 changes: 9 additions & 10 deletions problems/vrp/state_sdvrp.py
@@ -22,16 +22,15 @@ class StateSDVRP(NamedTuple):
     VEHICLE_CAPACITY = 1.0  # Hardcoded
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                ids=self.ids[key],
-                prev_a=self.prev_a[key],
-                used_capacity=self.used_capacity[key],
-                demands_with_depot=self.demands_with_depot[key],
-                lengths=self.lengths[key],
-                cur_coord=self.cur_coord[key],
-            )
-        return super(StateSDVRP, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            ids=self.ids[key],
+            prev_a=self.prev_a[key],
+            used_capacity=self.used_capacity[key],
+            demands_with_depot=self.demands_with_depot[key],
+            lengths=self.lengths[key],
+            cur_coord=self.cur_coord[key],
+        )
 
     @staticmethod
     def initialize(input):
35 changes: 16 additions & 19 deletions utils/beam_search.py
@@ -70,15 +70,14 @@ def ids(self):
         return self.state.ids.view(-1)  # Need to flat as state has steps dimension
 
     def __getitem__(self, key):
-        if torch.is_tensor(key) or isinstance(key, slice):  # If tensor, idx all tensors by this tensor:
-            return self._replace(
-                # ids=self.ids[key],
-                score=self.score[key] if self.score is not None else None,
-                state=self.state[key],
-                parent=self.parent[key] if self.parent is not None else None,
-                action=self.action[key] if self.action is not None else None
-            )
-        return super(BatchBeam, self).__getitem__(key)
+        assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
+        return self._replace(
+            # ids=self.ids[key],
+            score=self.score[key] if self.score is not None else None,
+            state=self.state[key],
+            parent=self.parent[key] if self.parent is not None else None,
+            action=self.action[key] if self.action is not None else None
+        )
 
     # Do not use __len__ since this is used by namedtuple internally and should be number of fields
     # def __len__(self):
@@ -207,15 +206,13 @@ def __getitem__(self, key):
         assert not isinstance(key, slice), "CachedLookup does not support slicing, " \
                                            "you can slice the result of an index operation instead"
 
-        if torch.is_tensor(key):  # If tensor, idx all tensors by this tensor:
+        assert torch.is_tensor(key)  # If tensor, idx all tensors by this tensor:
 
-            if self.key is None:
-                self.key = key
-                self.current = self.orig[key]
-            elif len(key) != len(self.key) or (key != self.key).any():
-                self.key = key
-                self.current = self.orig[key]
+        if self.key is None:
+            self.key = key
+            self.current = self.orig[key]
+        elif len(key) != len(self.key) or (key != self.key).any():
+            self.key = key
+            self.current = self.orig[key]
 
-            return self.current
-
-        return super(CachedLookup, self).__getitem__(key)
+        return self.current
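
For context, CachedLookup wraps a precomputed object and re-indexes it only when the key tensor actually changes; repeated lookups with an identical key return the cached result. A rough usage sketch follows; only __getitem__ appears in this diff, so the constructor below is a condensed assumption rather than the repository's code:

import torch


class CachedLookup(object):
    # Condensed sketch: cache the most recently indexed view of orig.
    def __init__(self, data):
        self.orig = data
        self.key = None
        self.current = None

    def __getitem__(self, key):
        assert torch.is_tensor(key)
        # Re-index only when there is no cached key or the key differs.
        if self.key is None or len(key) != len(self.key) or (key != self.key).any():
            self.key = key
            self.current = self.orig[key]
        return self.current


lookup = CachedLookup(torch.arange(10) * 10)
a = lookup[torch.tensor([1, 3])]  # indexes orig and caches the result
b = lookup[torch.tensor([1, 3])]  # same key: the cached tensor is returned
print(a is b)  # True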
