update docstr
Haichao-Zhang committed Dec 19, 2023
1 parent 944496e commit 82c59d1
Showing 1 changed file with 25 additions and 108 deletions.
133 changes: 25 additions & 108 deletions alf/algorithms/smodice_algorithm.py
@@ -49,83 +49,6 @@
SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())


# Discriminator over (state, action) pairs, implemented as an Algorithm.
class Discriminator_SA(Algorithm):
def __init__(self, observation_spec, action_spec):
super().__init__(observation_spec=observation_spec)

disc_net = CriticNetwork((observation_spec, action_spec))
self._disc_net = disc_net

def forward(self, inputs, state=()):
return self._disc_net(inputs, state)

def compute_grad_pen(self, expert_state, offline_state, lambda_=10):
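# WGAN-GP-style gradient penalty: penalize deviations of the discriminator's
# gradient norm from 1 on random interpolations of expert and offline samples.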
alpha = torch.rand(expert_state.size(0), 1)
expert_data = expert_state
offline_data = offline_state

alpha = alpha.expand_as(expert_data).to(expert_data.device)

mixup_data = alpha * expert_data + (1 - alpha) * offline_data
mixup_data.requires_grad = True

disc = self(mixup_data)
ones = torch.ones(disc.size()).to(disc.device)
grad = autograd.grad(
outputs=disc,
inputs=mixup_data,
grad_outputs=ones,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]

grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
return grad_pen

def update(self, expert_loader, offline_loader):
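# GAIL-style discriminator update: expert samples are labeled 1 and offline
# samples 0, regularized by the gradient penalty above.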
self.train()

loss = 0
n = 0
for expert_state, offline_state in zip(expert_loader, offline_loader):

expert_state = expert_state[0].to(self.device)
offline_state = offline_state[0][:expert_state.shape[0]].to(
self.device)

policy_d = self(offline_state)
expert_d = self(expert_state)

expert_loss = F.binary_cross_entropy_with_logits(
expert_d,
torch.ones(expert_d.size()).to(self.device))
policy_loss = F.binary_cross_entropy_with_logits(
policy_d,
torch.zeros(policy_d.size()).to(self.device))

gail_loss = expert_loss + policy_loss
grad_pen = self.compute_grad_pen(expert_state, offline_state)

loss += (gail_loss + grad_pen).item()
n += 1

self.optimizer.zero_grad()
(gail_loss + grad_pen).backward()
self.optimizer.step()
return loss / n

def predict_reward(self, state):
with torch.no_grad():
self.eval()
d = self(state)
s = torch.sigmoid(d)
# reward = log(d^E/d^O) = log(s) - log(1 - s), i.e. the raw discriminator logit
reward = s.log() - (1 - s).log()
return reward
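
For reference, the reward returned by ``predict_reward`` is exactly the discriminator's raw logit: with ``s = sigmoid(d)``, ``log(s) - log(1 - s) = d = log(d^E/d^O)``. A quick numerical check of the identity (illustration only, not part of this file):

import torch

d = torch.randn(5)                    # raw discriminator outputs (logits)
s = torch.sigmoid(d)
reward = s.log() - (1 - s).log()      # log(d^E/d^O)
assert torch.allclose(reward, d, atol=1e-5)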


@alf.configurable
class SmodiceAlgorithm(OffPolicyAlgorithm):
r"""SMODICE algorithm.
@@ -143,27 +66,24 @@ class SmodiceAlgorithm(OffPolicyAlgorithm):
ICML 2022.
"""

def __init__(
self,
observation_spec,
action_spec: BoundedTensorSpec,
reward_spec=TensorSpec(()),
actor_network_cls=ActorNetwork,
v_network_cls=ValueNetwork,
discriminator_network_cls=None,
actor_optimizer=None,
value_optimizer=None,
discriminator_optimizer=None,
#=====new params
gamma: float = 0.99,
v_l2_reg: float = 0.001,
env=None,
config: TrainerConfig = None,
checkpoint=None,
debug_summaries=False,
epsilon_greedy=None,
f="chi",
name="SmodiceAlgorithm"):
def __init__(self,
observation_spec,
action_spec: BoundedTensorSpec,
reward_spec=TensorSpec(()),
actor_network_cls=ActorNetwork,
v_network_cls=ValueNetwork,
discriminator_network_cls=None,
actor_optimizer=None,
value_optimizer=None,
discriminator_optimizer=None,
gamma: float = 0.99,
f="chi",
env=None,
config: TrainerConfig = None,
checkpoint=None,
debug_summaries=False,
epsilon_greedy=None,
name="SmodiceAlgorithm"):
"""
Args:
observation_spec (nested TensorSpec): representing the observations.
@@ -178,7 +98,13 @@ def __init__(
actor_network_cls (Callable): is used to construct the actor network.
The constructed actor network is a deterministic network and
will be used to generate continuous actions.
v_network_cls (Callable): is used to construct the value network.
discriminator_network_cls (Callable): is used to construct the discriminator.
actor_optimizer (torch.optim.optimizer): The optimizer for the actor network.
value_optimizer (torch.optim.optimizer): The optimizer for the value network.
discriminator_optimizer (torch.optim.optimizer): The optimizer for the discriminator.
gamma (float): the discount factor.
f (str): the form of the f-divergence to use. Currently supports 'chi' and 'kl'.
env (Environment): The environment to interact with. ``env`` is a
batched environment, which means that it runs multiple simulations
simultaneously. ``env`` only needs to be provided to the root
@@ -242,12 +168,9 @@ def __init__(
if discriminator_optimizer is not None and discriminator_net is not None:
self.add_optimizer(discriminator_optimizer, [discriminator_net])

self._actor_optimizer = actor_optimizer
self._value_optimizer = value_optimizer
self._v_l2_reg = v_l2_reg
self._gamma = gamma
self._f = f
assert f == "chi", "only support chi form"
assert f in ["chi", "kl"], "only supports 'chi' or 'kl' form"

# f-divergence functions
if self._f == 'chi':
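
The branch body is cut off by the hunk above. For orientation, SMODICE-style implementations commonly define the f-divergence and its Fenchel conjugate roughly as sketched below; the attribute names (``_f_fn``, ``_f_star``) and exact forms are assumptions, not taken from this file:

# Sketch only; the actual definitions are hidden by the truncated hunk.
if self._f == 'chi':
    # chi-square divergence f(x) = 0.5 * (x - 1)**2 and its conjugate
    self._f_fn = lambda x: 0.5 * (x - 1)**2
    self._f_star = lambda x: 0.5 * x**2 + x
elif self._f == 'kl':
    # KL divergence f(x) = x * log(x) and its conjugate
    self._f_fn = lambda x: x * torch.log(x + 1e-10)
    self._f_star = lambda x: torch.exp(x - 1)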
@@ -327,14 +250,8 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))

def value_train_step(self, inputs: TimeStep, state, rollout_info):
# initial_v_values, e_v, result={}
observation = inputs.observation

# extract initial observation from batch, or prepare a batch
initial_observation = observation

# Shared network values
# mini_batch_length
initial_v_values, _ = self._value_network(initial_observation)

# mini-batch len
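
The rest of ``value_train_step`` is hidden by the diff. In SMODICE the value objective combines an initial-state value term weighted by ``1 - gamma`` with the f-divergence conjugate applied to the Bellman residual; a rough sketch of what the hidden code likely computes (``next_observation`` and ``reward`` are assumed names, and ``_f_star`` refers to the sketch above):

# Sketch only; not taken from this diff.
v_values, _ = self._value_network(observation)
next_v_values, _ = self._value_network(next_observation)
residual = reward + self._gamma * next_v_values - v_values   # reward = log(d^E/d^O)
v_loss = (1 - self._gamma) * initial_v_values.mean() + self._f_star(residual).mean()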
