From 82c59d156b1366d1a796a13e427bcc3139e34563 Mon Sep 17 00:00:00 2001
From: Haichao Zhang
Date: Mon, 18 Dec 2023 21:58:42 -0800
Subject: [PATCH] update docstr

---
 alf/algorithms/smodice_algorithm.py | 133 ++++++----------------------
 1 file changed, 25 insertions(+), 108 deletions(-)

diff --git a/alf/algorithms/smodice_algorithm.py b/alf/algorithms/smodice_algorithm.py
index 8a60dec59..4f0f49ca3 100644
--- a/alf/algorithms/smodice_algorithm.py
+++ b/alf/algorithms/smodice_algorithm.py
@@ -49,83 +49,6 @@
 SmoLossInfo = namedtuple("SmoLossInfo", ["actor"], default_value=())
 
 
-# -> algorithm
-class Discriminator_SA(Algorithm):
-    def __init__(self, observation_spec, action_spec):
-        super().__init__(observation_spec=observation_spec)
-
-        disc_net = CriticNetwork((observation_spec, action_spec))
-        self._disc_net = disc_net
-
-    def forward(self, inputs, state=()):
-        return self._disc_net(inputs, state)
-
-    def compute_grad_pen(self, expert_state, offline_state, lambda_=10):
-        alpha = torch.rand(expert_state.size(0), 1)
-        expert_data = expert_state
-        offline_data = offline_state
-
-        alpha = alpha.expand_as(expert_data).to(expert_data.device)
-
-        mixup_data = alpha * expert_data + (1 - alpha) * offline_data
-        mixup_data.requires_grad = True
-
-        disc = self(mixup_data)
-        ones = torch.ones(disc.size()).to(disc.device)
-        grad = autograd.grad(
-            outputs=disc,
-            inputs=mixup_data,
-            grad_outputs=ones,
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True)[0]
-
-        grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean()
-        return grad_pen
-
-    def update(self, expert_loader, offline_loader):
-        self.train()
-
-        loss = 0
-        n = 0
-        for expert_state, offline_state in zip(expert_loader, offline_loader):
-
-            expert_state = expert_state[0].to(self.device)
-            offline_state = offline_state[0][:expert_state.shape[0]].to(
-                self.device)
-
-            policy_d = self(offline_state)
-            expert_d = self(expert_state)
-
-            expert_loss = F.binary_cross_entropy_with_logits(
-                expert_d,
-                torch.ones(expert_d.size()).to(self.device))
-            policy_loss = F.binary_cross_entropy_with_logits(
-                policy_d,
-                torch.zeros(policy_d.size()).to(self.device))
-
-            gail_loss = expert_loss + policy_loss
-            grad_pen = self.compute_grad_pen(expert_state, offline_state)
-
-            loss += (gail_loss + grad_pen).item()
-            n += 1
-
-            self.optimizer.zero_grad()
-            (gail_loss + grad_pen).backward()
-            self.optimizer.step()
-        return loss / n
-
-    def predict_reward(self, state):
-        with torch.no_grad():
-            self.eval()
-            d = self(state)
-            s = torch.sigmoid(d)
-            # log(d^E/d^O)
-            # reward = - (1/s-1).log()
-            reward = s.log() - (1 - s).log()
-            return reward
-
-
 @alf.configurable
 class SmodiceAlgorithm(OffPolicyAlgorithm):
     r"""SMODICE algorithm.
@@ -143,27 +66,24 @@ class SmodiceAlgorithm(OffPolicyAlgorithm):
        ICML 2022.
     """
 
-    def __init__(
-            self,
-            observation_spec,
-            action_spec: BoundedTensorSpec,
-            reward_spec=TensorSpec(()),
-            actor_network_cls=ActorNetwork,
-            v_network_cls=ValueNetwork,
-            discriminator_network_cls=None,
-            actor_optimizer=None,
-            value_optimizer=None,
-            discriminator_optimizer=None,
-            #=====new params
-            gamma: float = 0.99,
-            v_l2_reg: float = 0.001,
-            env=None,
-            config: TrainerConfig = None,
-            checkpoint=None,
-            debug_summaries=False,
-            epsilon_greedy=None,
-            f="chi",
-            name="SmodiceAlgorithm"):
+    def __init__(self,
+                 observation_spec,
+                 action_spec: BoundedTensorSpec,
+                 reward_spec=TensorSpec(()),
+                 actor_network_cls=ActorNetwork,
+                 v_network_cls=ValueNetwork,
+                 discriminator_network_cls=None,
+                 actor_optimizer=None,
+                 value_optimizer=None,
+                 discriminator_optimizer=None,
+                 gamma: float = 0.99,
+                 f="chi",
+                 env=None,
+                 config: TrainerConfig = None,
+                 checkpoint=None,
+                 debug_summaries=False,
+                 epsilon_greedy=None,
+                 name="SmodiceAlgorithm"):
         """
         Args:
             observation_spec (nested TensorSpec): representing the observations.
@@ -178,7 +98,13 @@ def __init__(
             actor_network_cls (Callable): is used to construct the actor
                 network. The constructed actor network is a determinstic network
                 and will be used to generate continuous actions.
+            v_network_cls (Callable): is used to construct the value network.
+            discriminator_network_cls (Callable): is used to construct the discriminator.
             actor_optimizer (torch.optim.optimizer): The optimizer for actor.
+            value_optimizer (torch.optim.optimizer): The optimizer for the value network.
+            discriminator_optimizer (torch.optim.optimizer): The optimizer for the discriminator.
+            gamma (float): the discount factor.
+            f (str): the form of f-divergence to use. Currently supports 'chi' and 'kl'.
             env (Environment): The environment to interact with. ``env`` is a
                 batched environment, which means that it runs multiple
                 simulations simultateously. ``env` only needs to be provided to the root
@@ -242,12 +168,9 @@ def __init__(
         if discriminator_optimizer is not None and discriminator_net is not None:
             self.add_optimizer(discriminator_optimizer, [discriminator_net])
 
-        self._actor_optimizer = actor_optimizer
-        self._value_optimizer = value_optimizer
-        self._v_l2_reg = v_l2_reg
         self._gamma = gamma
         self._f = f
-        assert f == "chi", "only support chi form"
+        assert f in ["chi", "kl"], "only support chi or kl form"
 
         # f-divergence functions
         if self._f == 'chi':
@@ -327,14 +250,8 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
         return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))
 
     def value_train_step(self, inputs: TimeStep, state, rollout_info):
-        # initial_v_values, e_v, result={}
         observation = inputs.observation
-
-        # extract initial observation from batch, or prepare a batch
         initial_observation = observation
-
-        # Shared network values
-        # mini_batch_length
         initial_v_values, _ = self._value_network(initial_observation)  # mini-batch len