Commit

Add gradient penalty
Haichao-Zhang committed Jan 9, 2024
1 parent 679aa11 commit 7065cda
Showing 5 changed files with 54 additions and 12 deletions.
alf/algorithms/algorithm.py: 5 changes (4 additions & 1 deletion)
@@ -2132,7 +2132,10 @@ def _hybrid_update(self, experience, batch_info, offline_experience,
else:
loss_info = offline_loss_info

params = self._backward_and_gradient_update(loss_info.loss * weight)
params, gns = self._backward_and_gradient_update(
loss_info.loss * weight)

loss_info = loss_info._replace(gns=gns)

if self._RL_train:
# for now, there is no need to do a hybrid after update
alf/algorithms/smodice_algorithm.py: 48 changes (39 additions & 9 deletions)
@@ -46,7 +46,8 @@
SmoCriticInfo = namedtuple("SmoCriticInfo",
["values", "initial_v_values", "is_first"])

SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())
SmoLossInfo = namedtuple(
"SmoLossInfo", ["actor", "grad_penalty"], default_value=())


@alf.configurable
@@ -77,7 +78,8 @@ def __init__(self,
value_optimizer=None,
discriminator_optimizer=None,
gamma: float = 0.99,
f="chi",
f: str = "chi",
gradient_penalty_weight: float = 1,
env=None,
config: TrainerConfig = None,
checkpoint=None,
@@ -104,7 +106,8 @@ def __init__(self,
value_optimizer (torch.optim.optimizer): The optimizer for value network.
discriminator_optimizer (torch.optim.optimizer): The optimizer for discriminator.
gamma (float): the discount factor.
f (str): the function form for f-divergence. Currently support 'chi' and 'kl'
f: the function form for f-divergence. Currently support 'chi' and 'kl'
gradient_penalty_weight: the weight for discriminator gradient penalty
env (Environment): The environment to interact with. ``env`` is a
batched environment, which means that it runs multiple simulations
simultaneously. ``env`` only needs to be provided to the root
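
For context (not part of this diff), the two options named in the ``f`` docstring correspond to the usual f-divergence generator functions in SMODICE-style objectives. A minimal sketch, assuming the standard forms; the function names ``f_chi`` and ``f_kl`` are illustrative, not ALF APIs:

```python
import torch

# Assumed generator functions for the supported f-divergences
# (standard choices in SMODICE-style objectives; the diff does not spell them out):
def f_chi(x: torch.Tensor) -> torch.Tensor:
    # chi-square divergence: f(x) = 0.5 * (x - 1)^2
    return 0.5 * (x - 1) ** 2

def f_kl(x: torch.Tensor) -> torch.Tensor:
    # KL divergence: f(x) = x * log(x)
    return x * torch.log(x)
```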
@@ -155,6 +158,7 @@ def __init__(self,
self._actor_network = actor_network
self._value_network = value_network
self._discriminator_net = discriminator_net
self._gradient_penalty_weight = gradient_penalty_weight

assert actor_optimizer is not None
if actor_optimizer is not None and actor_network is not None:
@@ -236,18 +240,44 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
"""
observation = inputs.observation
action = rollout_info.action
expert_logits, _ = self._discriminator_net((observation, action),
state)

discriminator_inputs = (observation, action)

if is_expert:
# turn on input gradient for gradient penalty in the case of expert data
for e in discriminator_inputs:
e.requires_grad = True

expert_logits, _ = self._discriminator_net(discriminator_inputs, state)

if is_expert:
grads = torch.autograd.grad(
outputs=expert_logits,
inputs=discriminator_inputs,
grad_outputs=torch.ones_like(expert_logits),
create_graph=True,
retain_graph=True,
only_inputs=True)

grad_pen = 0
for g in grads:
grad_pen += self._gradient_penalty_weight * (
g.norm(2, dim=1) - 1).pow(2)

label = torch.ones(expert_logits.size())
# turn on input gradient for gradient penalty in the case of expert data
for e in discriminator_inputs:
e.requires_grad = True
else:
label = torch.zeros(expert_logits.size())
grad_pen = ()

expert_loss = F.binary_cross_entropy_with_logits(
expert_logits, label, reduction='none')

return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))
return LossInfo(
loss=expert_loss if grad_pen == () else expert_loss + grad_pen,
extra=SmoLossInfo(actor=expert_loss, grad_penalty=grad_pen))

def value_train_step(self, inputs: TimeStep, state, rollout_info):
observation = inputs.observation
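
For readers unfamiliar with the pattern added to ``_discriminator_train_step`` above, here is a minimal standalone sketch (not part of the commit) of this kind of WGAN-GP-style gradient penalty: differentiate the discriminator output with respect to its expert inputs and penalize the squared deviation of the gradient norm from 1. The names ``gradient_penalty``, ``discriminator``, ``observation``, ``action``, and ``weight`` below are illustrative stand-ins, not ALF APIs:

```python
import torch
import torch.nn as nn

def gradient_penalty(discriminator: nn.Module,
                     observation: torch.Tensor,
                     action: torch.Tensor,
                     weight: float = 1.0) -> torch.Tensor:
    """Per-sample penalty pushing ||dD/dinput||_2 toward 1 on expert data."""
    inputs = (observation.detach().requires_grad_(True),
              action.detach().requires_grad_(True))
    logits = discriminator(torch.cat(inputs, dim=-1))
    grads = torch.autograd.grad(
        outputs=logits,
        inputs=inputs,
        grad_outputs=torch.ones_like(logits),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
        retain_graph=True)
    penalty = torch.zeros(logits.shape[0])
    for g in grads:
        penalty = penalty + weight * (g.norm(2, dim=1) - 1).pow(2)
    return penalty  # shape [batch_size]; added to the per-sample discriminator loss

# Toy usage:
disc = nn.Sequential(nn.Linear(6, 64), nn.ReLU(), nn.Linear(64, 1))
obs, act = torch.randn(32, 4), torch.randn(32, 2)
pen = gradient_penalty(disc, obs, act, weight=0.1)  # tensor of shape [32]
```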
@@ -285,7 +315,7 @@ def train_step(self,
alf.summary.scalar("imitation_loss_online",
actor_loss.loss.mean())
alf.summary.scalar("discriminator_loss_online",
expert_disc_loss.loss.mean())
expert_disc_loss.extra.actor.mean())

# use predicted reward
reward = self.predict_reward(inputs, rollout_info)
@@ -305,7 +335,6 @@ def train_step_offline(self,
state,
rollout_info,
pre_train=False):

action_dist, new_state = self._predict_action(
inputs.observation, state=state.actor)

@@ -324,7 +353,8 @@
actor_loss.loss.mean())
alf.summary.scalar("discriminator_loss_offline",
expert_disc_loss.loss.mean())

alf.summary.scalar("grad_penalty",
expert_disc_loss.extra.grad_penalty.mean())
# use predicted reward
reward = self.predict_reward(inputs, rollout_info)

alf/examples/data_collection_carla_conf.py: 2 changes (1 addition & 1 deletion)
@@ -27,7 +27,7 @@
# This is an example config file for data collection in CARLA.

# the desired replay buffer size for collection
# 100 is just an example. Should set it to he actual desired size.
# 100 is just an example. Should set it to the actual desired size.
replay_buffer_length = 100

# the desired environment for data collection
alf/examples/smodice_bipedal_walker_conf.py: 9 changes (8 additions & 1 deletion)
@@ -30,7 +30,7 @@

offline_buffer_length = None
offline_buffer_dir = [
"./hybrid_rl/replay_buffer_data/pendulum_replay_buffer_from_sac_10k"
"/home/haichaozhang/data/DATA/sac_bipedal_baseline/train/algorithm/ckpt-80000-replay_buffer"
]

alf.config('Agent', rl_algorithm_cls=SmodiceAlgorithm, optimizer=None)
@@ -67,4 +67,11 @@
# add weight decay to the v_net following smodice paper
value_optimizer=alf.optimizers.Adam(lr=lr, weight_decay=1e-4),
discriminator_optimizer=alf.optimizers.Adam(lr=lr),
gradient_penalty_weight=0.1,
)

# training config
alf.config(
"TrainerConfig",
offline_buffer_dir=offline_buffer_dir,
offline_buffer_length=offline_buffer_length)
alf/examples/smodice_pendulum_conf.py: 2 changes (2 additions & 0 deletions)
@@ -76,6 +76,7 @@
# add weight decay to the v_net following smodice paper
value_optimizer=alf.optimizers.Adam(lr=lr, weight_decay=1e-4),
discriminator_optimizer=alf.optimizers.Adam(lr=lr),
gradient_penalty_weight=0.1,
)

num_iterations = 1000000
@@ -91,6 +92,7 @@
rl_train_after_update_steps=0, # joint training
mini_batch_size=256,
mini_batch_length=2,
unroll_length=1,
offline_buffer_dir=offline_buffer_dir,
offline_buffer_length=offline_buffer_length,
num_checkpoints=1,
