Skip to content

Commit

Permalink
add RL results
Browse files Browse the repository at this point in the history
  • Loading branch information
qlan3 committed Jan 11, 2024
1 parent b1f6bed commit bdfe0dd
Show file tree
Hide file tree
Showing 50 changed files with 2,107 additions and 32 deletions.
17 changes: 10 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Loss of Plasticity in Deep Continual Learning

This repository contains the implementation of three continual supervised learning problems.
In our forthcoming paper _Maintaining Plasticity in Deep Continual Learning_,
In our forthcoming paper _Loss Plasticity in Deep Continual Learning_,
we show the loss of plasticity in deep learning in these problems.

A talk about this work can be found [here](https://www.youtube.com/watch?v=p_zknyfV9fY),
Expand All @@ -9,15 +10,17 @@ and the [paper](https://arxiv.org/abs/2306.13812) is available on arxiv.
# Installation

```sh
virtualenv --python=/usr/bin/python3.8 loss-of-plasticity/
source loss-of-plasticity/bin/activate
mkdir ~/envs
virtualenv --no-download --python=/usr/bin/python3.8 ~/envs/lop
source ~/envs/lop/bin/activate
pip3 install --no-index --upgrade pip
git clone https://github.com/shibhansh/loss-of-plasticity.git
cd loss-of-plasticity
pip3 install -r requirements.txt
pip3 install -e .
```

Add these lines in your .zshrc
Add these lines in your ~/.zshrc or ~/.bashrc
```sh
source PATH_TO_DIR/loss-of-plasticity/lop/bin/activate
export PYTHONPATH=$PATH:PATH_TO_DIR/lop
```
source ~/envs/lop/bin/activate
```
7 changes: 7 additions & 0 deletions install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mkdir ~/envs
virtualenv --no-download ~/envs/lop
source ~/envs/lop/bin/activate
pip3 install --no-index --upgrade pip
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip3 install -r requirements.txt
pip3 install -e .
4 changes: 2 additions & 2 deletions lop/algos/bp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def learn(self, x, target):
def perturb(self):
with torch.no_grad():
for i in range(int(len(self.net.layers)/2)+1):
self.net.layers[i * 2].bias -=-\
self.net.layers[i * 2].bias +=\
torch.empty(self.net.layers[i * 2].bias.shape, device=self.device).normal_(mean=0, std=self.perturb_scale)
self.net.layers[i * 2].weight -=-\
self.net.layers[i * 2].weight +=\
torch.empty(self.net.layers[i * 2].weight.shape, device=self.device).normal_(mean=0, std=self.perturb_scale)
16 changes: 8 additions & 8 deletions lop/algos/convGnT.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,16 @@ def update_utility(self, layer_idx=0, features=None):
self.mean_abs_feature_act[layer_idx] *= self.decay_rate
if isinstance(current_layer, Linear):
input_wight_mag = current_layer.weight.data.abs().mean(dim=1)
self.mean_feature_act[layer_idx] -=- (1 - self.decay_rate) * features.mean(dim=0)
self.mean_abs_feature_act[layer_idx] -=- (1 - self.decay_rate) * features.abs().mean(dim=0)
self.mean_feature_act[layer_idx] += (1 - self.decay_rate) * features.mean(dim=0)
self.mean_abs_feature_act[layer_idx] += (1 - self.decay_rate) * features.abs().mean(dim=0)
elif isinstance(current_layer, Conv2d):
input_wight_mag = current_layer.weight.data.abs().mean(dim=(1, 2, 3))
if isinstance(next_layer, Conv2d):
self.mean_feature_act[layer_idx] -=- (1 - self.decay_rate) * features.mean(dim=(0, 2, 3))
self.mean_abs_feature_act[layer_idx] -=- (1 - self.decay_rate) * features.abs().mean(dim=(0, 2, 3))
self.mean_feature_act[layer_idx] += (1 - self.decay_rate) * features.mean(dim=(0, 2, 3))
self.mean_abs_feature_act[layer_idx] += (1 - self.decay_rate) * features.abs().mean(dim=(0, 2, 3))
else:
self.mean_feature_act[layer_idx] -=- (1 - self.decay_rate) * features.mean(dim=0).view(-1, self.num_last_filter_outputs).mean(dim=1)
self.mean_abs_feature_act[layer_idx] -=- (1 - self.decay_rate) * features.abs().mean(dim=0).view(-1, self.num_last_filter_outputs).mean(dim=1)
self.mean_feature_act[layer_idx] += (1 - self.decay_rate) * features.mean(dim=0).view(-1, self.num_last_filter_outputs).mean(dim=1)
self.mean_abs_feature_act[layer_idx] += (1 - self.decay_rate) * features.abs().mean(dim=0).view(-1, self.num_last_filter_outputs).mean(dim=1)

bias_corrected_act = self.mean_feature_act[layer_idx] / bias_correction

Expand Down Expand Up @@ -137,7 +137,7 @@ def update_utility(self, layer_idx=0, features=None):
if self.util_type == 'random':
self.bias_corrected_util[layer_idx] = rand(self.util[layer_idx].shape)
else:
self.util[layer_idx] -=- (1 - self.decay_rate) * new_util
self.util[layer_idx] += (1 - self.decay_rate) * new_util
# correct the bias in the utility computation
self.bias_corrected_util[layer_idx] = self.util[layer_idx] / bias_correction

Expand Down Expand Up @@ -166,7 +166,7 @@ def test_features(self, features):
eligible_feature_indices = where(self.ages[i] > self.maturity_threshold)[0]
if eligible_feature_indices.shape[0] == 0:
continue
self.accumulated_num_features_to_replace[i] -=- self.num_new_features_to_replace[i]
self.accumulated_num_features_to_replace[i] += self.num_new_features_to_replace[i]

"""
Case when the number of features to be replaced is between 0 and 1.
Expand Down
6 changes: 3 additions & 3 deletions lop/algos/gnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def update_utility(self, layer_idx=0, features=None, next_features=None):
else:
new_util = 0

self.util[layer_idx] -=- (1 - self.decay_rate) * new_util
self.util[layer_idx] += (1 - self.decay_rate) * new_util

"""
Adam-style bias correction
Expand Down Expand Up @@ -190,14 +190,14 @@ def gen_new_features(self, features_to_replace, num_features_to_replace):
next_layer = self.net[i * 2 + 2]
current_layer.weight.data[features_to_replace[i], :] *= 0.0
# noinspection PyArgumentList
current_layer.weight.data[features_to_replace[i], :] -=- \
current_layer.weight.data[features_to_replace[i], :] += \
torch.empty(num_features_to_replace[i], current_layer.in_features).uniform_(
-self.bounds[i], self.bounds[i]).to(self.device)
current_layer.bias.data[features_to_replace[i]] *= 0
"""
# Update bias to correct for the removed features and set the outgoing weights and ages to zero
"""
next_layer.bias.data -=- (next_layer.weight.data[:, features_to_replace[i]] * \
next_layer.bias.data += (next_layer.weight.data[:, features_to_replace[i]] * \
self.mean_feature_act[i][features_to_replace[i]] / \
(1 - self.decay_rate ** self.ages[i][features_to_replace[i]])).sum(dim=1)
next_layer.weight.data[:, features_to_replace[i]] = 0
Expand Down
86 changes: 86 additions & 0 deletions lop/algos/gntRedo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import torch
from math import sqrt


class GnTredo(object):
"""
Generate-and-Test algorithm for feed forward neural networks, based on ReDo
"""
def __init__(self, net, hidden_activation, threshold=0.01, init='kaiming', device="cpu", reset_period=1000):
super(GnTredo, self).__init__()
self.device = device
self.net = net
self.num_hidden_layers = int(len(self.net)/2)
self.threshold = threshold
self.steps_since_last_redo = 0
self.reset_period = reset_period
# Calculate uniform distribution's bound for random feature initialization
if hidden_activation == 'selu': init = 'lecun'
self.bounds = self.compute_bounds(hidden_activation=hidden_activation, init=init)

def compute_bounds(self, hidden_activation, init='kaiming'):
if hidden_activation in ['swish', 'elu']: hidden_activation = 'relu'
if init == 'default':
bounds = [sqrt(1 / self.net[i * 2].in_features) for i in range(self.num_hidden_layers)]
elif init == 'xavier':
bounds = [torch.nn.init.calculate_gain(nonlinearity=hidden_activation) *
sqrt(6 / (self.net[i * 2].in_features + self.net[i * 2].out_features)) for i in
range(self.num_hidden_layers)]
elif init == 'lecun':
bounds = [sqrt(3 / self.net[i * 2].in_features) for i in range(self.num_hidden_layers)]
else:
bounds = [torch.nn.init.calculate_gain(nonlinearity=hidden_activation) *
sqrt(3 / self.net[i * 2].in_features) for i in range(self.num_hidden_layers)]
bounds.append(1 * sqrt(3 / self.net[self.num_hidden_layers * 2].in_features))
return bounds

def units_to_replace(self, features):
"""
Args:
features: Activation values in the neural network, mini-batch * layer-idx * feature-idx
Returns:
Features to replace in each layer, Number of features to replace in each layer
"""
features = features.mean(dim=0)
features_to_replace = [None]*self.num_hidden_layers
num_features_to_replace = [None]*self.num_hidden_layers
for i in range(self.num_hidden_layers):
# Find features to replace
feature_utility = features[i] / features[i].mean()
new_features_to_replace = (feature_utility <= self.threshold).nonzero().reshape(-1)
# Initialize utility for new features
features_to_replace[i] = new_features_to_replace
num_features_to_replace[i] = new_features_to_replace.shape[0]

return features_to_replace, num_features_to_replace

def gen_new_features(self, features_to_replace, num_features_to_replace):
"""
Generate new features: Reset input and output weights for low utility features
"""
with torch.no_grad():
for i in range(self.num_hidden_layers):
if num_features_to_replace[i] == 0:
continue
current_layer = self.net[i * 2]
next_layer = self.net[i * 2 + 2]
current_layer.weight.data[features_to_replace[i], :] *= 0.0
current_layer.weight.data[features_to_replace[i], :] += \
torch.empty(num_features_to_replace[i], current_layer.in_features).uniform_(
-self.bounds[i], self.bounds[i]).to(self.device)
current_layer.bias.data[features_to_replace[i]] *= 0

next_layer.weight.data[:, features_to_replace[i]] = 0

def gen_and_test(self, features_history):
"""
Perform generate-and-test
:param features: activation of hidden units in the neural network
"""
self.steps_since_last_redo += 1
if self.steps_since_last_redo < self.reset_period:
return

features_to_replace, num_features_to_replace = self.units_to_replace(features=features_history.abs())
self.gen_new_features(features_to_replace, num_features_to_replace)
self.steps_since_last_redo = 0
2 changes: 1 addition & 1 deletion lop/algos/res_gnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_features(self, features):
eligible_feature_indices = where(self.ages[i] > self.maturity_threshold)[0]
if eligible_feature_indices.shape[0] == 0:
continue
self.accumulated_num_features_to_replace[i] -=- self.num_new_features_to_replace[i]
self.accumulated_num_features_to_replace[i] += self.num_new_features_to_replace[i]

"""
Case when the number of features to be replaced is between 0 and 1.
Expand Down
Empty file added lop/algos/rl/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions lop/algos/rl/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch


class Agent(object):
def __init__(self, pol, learner, device='cpu', to_log_features=False):
self.pol = pol
self.learner = learner
self.device = device
self.to_log_features = to_log_features

def get_action(self, o):
"""
:param o: np. array of shape (1,)
:return: a two tuple
- np.array of shape (1,)
- np.array of shape (1,)
"""
action, lprob, dist = self.pol.action(torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0),
to_log_features=self.to_log_features)
features = None
if self.to_log_features:
features = self.pol.get_activations()
return action[0].cpu().numpy(), lprob.cpu().numpy(), self.pol.dist_to(dist, to_device='cpu'), features

def log_update(self, o, a, r, op, logp, dist, done):
return self.learner.log_update(o, a, r, op, logp, dist, done)

def preprocess_state(self, state):
return state

def choose_action(self, o, epsilon):
return self.get_action(o=o)[0]
52 changes: 52 additions & 0 deletions lop/algos/rl/buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import collections as c

import torch


class Buffer(object):
def __init__(self, o_dim, a_dim, bs, device='cpu'):
self.o_dim = o_dim
self.a_dim = a_dim
self.bs = bs
self.device = device
self.o_buf, self.a_buf, self.r_buf, self.logpb_buf, self.distb_buf, self.done_buf = \
c.deque(), c.deque(), c.deque(), c.deque(), c.deque(), c.deque()
self.op = np.zeros((1, o_dim), dtype=np.float32)

def store(self, o, a, r, op, logpb, dist, done):
self.o_buf.append(o)
self.a_buf.append(a)
self.r_buf.append(r)
self.logpb_buf.append(logpb)
self.distb_buf.append(dist)
self.done_buf.append(float(done))
self.op[:] = op

def pop(self):
self.o_buf.popleft()
self.a_buf.popleft()
self.r_buf.popleft()
self.logpb_buf.popleft()
self.distb_buf.popleft()
self.done_buf.popleft()

def clear(self):
self.o_buf.clear()
self.a_buf.clear()
self.r_buf.clear()
self.logpb_buf.clear()
self.distb_buf.clear()
self.done_buf.clear()

def get(self, dist_stack):
rang = range(self.bs)
os = torch.as_tensor(np.array([self.o_buf[i] for i in rang]), dtype=torch.float32, device=self.device).view(-1, self.o_dim)
acts = torch.as_tensor(np.array([self.a_buf[i] for i in rang]), dtype=torch.float32, device=self.device).view(-1, self.a_dim)
rs = torch.as_tensor(np.array([self.r_buf[i] for i in rang]), dtype=torch.float32, device=self.device).view(-1, 1)
op = torch.as_tensor(self.op, device=self.device).view(-1, self.o_dim)
logpbs = torch.as_tensor(np.array([self.logpb_buf[i] for i in rang]), dtype=torch.float32, device=self.device).view(-1, 1)
distbs = dist_stack([self.distb_buf[i] for i in rang], device=self.device)
dones = torch.as_tensor(np.array([self.done_buf[i] for i in rang]), dtype=torch.float32, device=self.device).view(-1, 1)

return os, acts, rs, op, logpbs, distbs, dones
22 changes: 22 additions & 0 deletions lop/algos/rl/learner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
class Learner(object):
def log_update(self, o, a, r, op, logpb, dist, done):
self.log(o, a, r, op, logpb, dist, done)
info0 = {'learned': False}
if self.learn_time(done):
info = self.learn()
self.post_learn()
info0.update(info)
info0['learned'] = True
return info0

def log(self, o, a, r, op, logpb, dist, done):
pass

def learn_time(self, done):
pass

def post_learn(self):
pass

def learn(self, env=None):
pass
Loading

0 comments on commit bdfe0dd

Please sign in to comment.