-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
50 changed files
with
2,107 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Create an isolated virtualenv for the project and install its dependencies.

# -p: do not fail when ~/envs already exists (e.g. when re-running this script).
mkdir -p ~/envs
# --no-download: seed the environment from locally bundled pip/setuptools wheels.
virtualenv --no-download ~/envs/lop
source ~/envs/lop/bin/activate
# --no-index: upgrade pip without contacting the package index.
pip3 install --no-index --upgrade pip
# Pinned CPU-only PyTorch wheels from the official PyTorch wheel index.
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip3 install -r requirements.txt
# Install this project in editable (development) mode.
pip3 install -e .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
from math import sqrt | ||
|
||
|
||
class GnTredo(object):
    """
    Generate-and-Test algorithm for feed forward neural networks, based on ReDo.

    Every ``reset_period`` calls to :meth:`gen_and_test`, hidden units whose
    normalised mean activation is at most ``threshold`` are re-initialised:
    their incoming weights are redrawn from a uniform distribution, their bias
    is zeroed and their outgoing weights are zeroed.

    ``net`` is indexed as ``net[2 * i]`` for the i-th weight layer, and each of
    those layers must expose ``in_features``, ``out_features`` and ``weight``
    (and optionally ``bias``).
    # NOTE(review): presumably an nn.Sequential alternating Linear layers and
    # activations -- confirm against the caller.
    """

    # Keeps the utility well defined when every unit of a layer has zero activation.
    _UTILITY_EPS = 1e-9

    def __init__(self, net, hidden_activation, threshold=0.01, init='kaiming', device="cpu", reset_period=1000):
        """
        Args:
            net: Indexable network whose even positions hold the weight layers.
            hidden_activation: Name of the hidden non-linearity, used to pick the init gain.
            threshold: Units with normalised utility <= threshold are re-initialised.
            init: One of 'default', 'xavier', 'lecun'; anything else selects the kaiming-style bound.
            device: Device on which freshly drawn weights are placed.
            reset_period: Number of gen_and_test calls between two resets.
        """
        super(GnTredo, self).__init__()
        self.device = device
        self.net = net
        self.num_hidden_layers = len(self.net) // 2
        self.threshold = threshold
        self.steps_since_last_redo = 0
        self.reset_period = reset_period
        # Calculate uniform distribution's bound for random feature initialization
        if hidden_activation == 'selu':
            # SELU networks always use the lecun bound, whatever `init` was requested.
            init = 'lecun'
        self.bounds = self.compute_bounds(hidden_activation=hidden_activation, init=init)

    def compute_bounds(self, hidden_activation, init='kaiming'):
        """
        Compute the bound of the uniform distribution used to re-initialise each layer.

        Returns:
            A list with one bound per hidden layer followed by one bound for the
            layer at index ``2 * num_hidden_layers``.
        """
        if hidden_activation in ['swish', 'elu']:
            # The gain of relu is used for these activations.
            hidden_activation = 'relu'
        hidden_layers = [self.net[i * 2] for i in range(self.num_hidden_layers)]
        if init == 'default':
            bounds = [sqrt(1 / layer.in_features) for layer in hidden_layers]
        elif init == 'xavier':
            gain = torch.nn.init.calculate_gain(nonlinearity=hidden_activation)
            bounds = [gain * sqrt(6 / (layer.in_features + layer.out_features)) for layer in hidden_layers]
        elif init == 'lecun':
            bounds = [sqrt(3 / layer.in_features) for layer in hidden_layers]
        else:
            gain = torch.nn.init.calculate_gain(nonlinearity=hidden_activation)
            bounds = [gain * sqrt(3 / layer.in_features) for layer in hidden_layers]
        # The last layer always uses a gain of 1.
        bounds.append(sqrt(3 / self.net[self.num_hidden_layers * 2].in_features))
        return bounds

    def units_to_replace(self, features):
        """
        Args:
            features: Activation values in the neural network, mini-batch * layer-idx * feature-idx
        Returns:
            Features to replace in each layer, Number of features to replace in each layer
        """
        features = features.mean(dim=0)
        features_to_replace = [None] * self.num_hidden_layers
        num_features_to_replace = [None] * self.num_hidden_layers
        for i in range(self.num_hidden_layers):
            # Utility of a unit is its mean activation relative to the layer average.
            # The epsilon avoids 0 / 0 = NaN for a fully dormant layer; NaN would
            # never compare <= threshold, so such a layer would never be reset.
            feature_utility = features[i] / (features[i].mean() + self._UTILITY_EPS)
            low_utility = (feature_utility <= self.threshold).nonzero().reshape(-1)
            features_to_replace[i] = low_utility
            num_features_to_replace[i] = low_utility.shape[0]

        return features_to_replace, num_features_to_replace

    def gen_new_features(self, features_to_replace, num_features_to_replace):
        """
        Generate new features: Reset input and output weights for low utility features

        Args:
            features_to_replace: Per hidden layer, index tensor of the units to reset.
            num_features_to_replace: Per hidden layer, number of units to reset.
        """
        with torch.no_grad():
            for i in range(self.num_hidden_layers):
                if num_features_to_replace[i] == 0:
                    continue
                current_layer = self.net[i * 2]
                next_layer = self.net[i * 2 + 2]
                idx = features_to_replace[i]
                new_weights = torch.empty(num_features_to_replace[i], current_layer.in_features).uniform_(
                    -self.bounds[i], self.bounds[i])
                # Assign instead of "*= 0 then +=": multiplying by zero keeps NaN/Inf
                # weights, so a diverged unit would never really be reset. The dtype
                # cast is required because indexed assignment does not type-promote.
                current_layer.weight.data[idx, :] = new_weights.to(
                    device=self.device, dtype=current_layer.weight.dtype)
                if current_layer.bias is not None:
                    current_layer.bias.data[idx] = 0

                # Zero outgoing weights so the new features do not disturb the output.
                next_layer.weight.data[:, idx] = 0

    def gen_and_test(self, features_history):
        """
        Perform generate-and-test
        :param features_history: activation of hidden units in the neural network,
            mini-batch * layer-idx * feature-idx; absolute values are used.
        """
        self.steps_since_last_redo += 1
        if self.steps_since_last_redo < self.reset_period:
            return

        features_to_replace, num_features_to_replace = self.units_to_replace(features=features_history.abs())
        self.gen_new_features(features_to_replace, num_features_to_replace)
        self.steps_since_last_redo = 0
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
|
||
|
||
class Agent(object):
    """Thin wrapper pairing a policy (`pol`) with a learner (`learner`)."""

    def __init__(self, pol, learner, device='cpu', to_log_features=False):
        """
        :param pol: policy object; must provide `action`, `get_activations` and `dist_to`
        :param learner: learner object; must provide `log_update`
        :param device: device on which observations are turned into tensors
        :param to_log_features: when True, `get_action` also returns the policy activations
        """
        self.pol = pol
        self.learner = learner
        self.device = device
        self.to_log_features = to_log_features

    def get_action(self, o):
        """
        Query the policy for an action on a single observation.

        :param o: observation (array-like); it is converted to a float32 tensor
            and given a leading batch dimension of size 1
        :return: a four tuple
            - np.array, the first (only) action of the batch
            - np.array, the log-probability returned by the policy
            - the policy's distribution, moved to the cpu via `pol.dist_to`
            - `pol.get_activations()` if `to_log_features` is set, otherwise None
        """
        obs = torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
        action, lprob, dist = self.pol.action(obs, to_log_features=self.to_log_features)
        features = None
        if self.to_log_features:
            features = self.pol.get_activations()
        return action[0].cpu().numpy(), lprob.cpu().numpy(), self.pol.dist_to(dist, to_device='cpu'), features

    def log_update(self, o, a, r, op, logp, dist, done):
        """Forward one transition to the learner and return whatever it returns."""
        return self.learner.log_update(o, a, r, op, logp, dist, done)

    def preprocess_state(self, state):
        """Return `state` unchanged."""
        return state

    def choose_action(self, o, epsilon):
        """
        Return only the action for observation `o`.

        `epsilon` is accepted but not used by this method.
        """
        return self.get_action(o=o)[0]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
import collections as c | ||
|
||
import torch | ||
|
||
|
||
class Buffer(object):
    """
    FIFO store of transitions backed by one deque per field.

    `get` returns the `bs` oldest stored transitions as tensors, together with
    the most recently stored next observation (`op`).
    """

    def __init__(self, o_dim, a_dim, bs, device='cpu'):
        """
        :param o_dim: observation width used to reshape observation tensors
        :param a_dim: action width used to reshape action tensors
        :param bs: number of transitions returned by `get`
        :param device: device on which `get` creates its tensors
        """
        self.o_dim = o_dim
        self.a_dim = a_dim
        self.bs = bs
        self.device = device
        self.o_buf, self.a_buf, self.r_buf, self.logpb_buf, self.distb_buf, self.done_buf = \
            c.deque(), c.deque(), c.deque(), c.deque(), c.deque(), c.deque()
        # Only the latest next observation is kept; it is overwritten on every store.
        self.op = np.zeros((1, o_dim), dtype=np.float32)

    def store(self, o, a, r, op, logpb, dist, done):
        """Append one transition and remember `op` as the latest next observation."""
        self.o_buf.append(o)
        self.a_buf.append(a)
        self.r_buf.append(r)
        self.logpb_buf.append(logpb)
        self.distb_buf.append(dist)
        self.done_buf.append(float(done))
        self.op[:] = op

    def pop(self):
        """Drop the oldest transition from every field."""
        self.o_buf.popleft()
        self.a_buf.popleft()
        self.r_buf.popleft()
        self.logpb_buf.popleft()
        self.distb_buf.popleft()
        self.done_buf.popleft()

    def clear(self):
        """Remove all stored transitions (the latest `op` is left untouched)."""
        self.o_buf.clear()
        self.a_buf.clear()
        self.r_buf.clear()
        self.logpb_buf.clear()
        self.distb_buf.clear()
        self.done_buf.clear()

    def _head(self, buf):
        """Return the `bs` oldest entries of `buf` as a list."""
        # zip stops after `bs` items; iterating avoids repeated deque indexing.
        return [item for _, item in zip(range(self.bs), buf)]

    def _as_column(self, buf, width):
        """Stack the `bs` oldest entries of `buf` into a float32 tensor with `width` columns."""
        return torch.as_tensor(np.array(self._head(buf)), dtype=torch.float32, device=self.device).view(-1, width)

    def get(self, dist_stack):
        """
        Return the `bs` oldest transitions as tensors.

        :param dist_stack: callable `dist_stack(list_of_dists, device=...)` that
            combines the stored distributions into one batch object
        :return: (os, acts, rs, op, logpbs, distbs, dones)
        :raises IndexError: if fewer than `bs` transitions are stored
        """
        if len(self.o_buf) < self.bs:
            raise IndexError(
                'Buffer holds %d transitions but a batch of %d was requested' % (len(self.o_buf), self.bs))
        os = self._as_column(self.o_buf, self.o_dim)
        acts = self._as_column(self.a_buf, self.a_dim)
        rs = self._as_column(self.r_buf, 1)
        # torch.tensor copies. torch.as_tensor would share memory with self.op on
        # the cpu, so a later store() would silently change a returned batch.
        op = torch.tensor(self.op, device=self.device).view(-1, self.o_dim)
        logpbs = self._as_column(self.logpb_buf, 1)
        distbs = dist_stack(self._head(self.distb_buf), device=self.device)
        dones = self._as_column(self.done_buf, 1)

        return os, acts, rs, op, logpbs, distbs, dones
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
class Learner(object):
    """
    Base class for learners.

    `log_update` drives the hooks below; subclasses override `log`,
    `learn_time`, `learn` and `post_learn`. The base implementations do nothing.
    """

    def log_update(self, o, a, r, op, logpb, dist, done):
        """
        Record one transition and run a learning step when `learn_time` says so.

        :return: dict with key 'learned' (True if a learning step ran), merged
            with whatever mapping `learn` returned
        """
        self.log(o, a, r, op, logpb, dist, done)
        info0 = {'learned': False}
        if self.learn_time(done):
            info = self.learn()
            self.post_learn()
            # `learn` may return nothing (the base implementation does);
            # dict.update(None) would raise TypeError.
            if info:
                info0.update(info)
            info0['learned'] = True
        return info0

    def log(self, o, a, r, op, logpb, dist, done):
        """Hook: record one transition. No-op in the base class."""

    def learn_time(self, done):
        """Hook: return a truthy value when a learning step should run. Never in the base class."""
        return False

    def post_learn(self):
        """Hook: called right after `learn`. No-op in the base class."""

    def learn(self, env=None):
        """Hook: perform a learning step and optionally return an info dict. No-op in the base class."""
Oops, something went wrong.