update docstr
Haichao-Zhang committed Dec 19, 2023
1 parent 944496e commit 6763d36
Showing 1 changed file with 25 additions and 31 deletions.
56 changes: 25 additions & 31 deletions alf/algorithms/smodice_algorithm.py
@@ -143,27 +143,24 @@ class SmodiceAlgorithm(OffPolicyAlgorithm):
     ICML 2022.
     """
 
-    def __init__(
-            self,
-            observation_spec,
-            action_spec: BoundedTensorSpec,
-            reward_spec=TensorSpec(()),
-            actor_network_cls=ActorNetwork,
-            v_network_cls=ValueNetwork,
-            discriminator_network_cls=None,
-            actor_optimizer=None,
-            value_optimizer=None,
-            discriminator_optimizer=None,
-            #=====new params
-            gamma: float = 0.99,
-            v_l2_reg: float = 0.001,
-            env=None,
-            config: TrainerConfig = None,
-            checkpoint=None,
-            debug_summaries=False,
-            epsilon_greedy=None,
-            f="chi",
-            name="SmodiceAlgorithm"):
+    def __init__(self,
+                 observation_spec,
+                 action_spec: BoundedTensorSpec,
+                 reward_spec=TensorSpec(()),
+                 actor_network_cls=ActorNetwork,
+                 v_network_cls=ValueNetwork,
+                 discriminator_network_cls=None,
+                 actor_optimizer=None,
+                 value_optimizer=None,
+                 discriminator_optimizer=None,
+                 gamma: float = 0.99,
+                 f="chi",
+                 env=None,
+                 config: TrainerConfig = None,
+                 checkpoint=None,
+                 debug_summaries=False,
+                 epsilon_greedy=None,
+                 name="SmodiceAlgorithm"):
         """
         Args:
             observation_spec (nested TensorSpec): representing the observations.
@@ -178,7 +175,13 @@ def __init__(
             actor_network_cls (Callable): is used to construct the actor network.
                 The constructed actor network is a deterministic network and
                 will be used to generate continuous actions.
             v_network_cls (Callable): is used to construct the value network.
+            discriminator_network_cls (Callable): is used to construct the discriminator.
+            actor_optimizer (torch.optim.Optimizer): the optimizer for the actor.
+            value_optimizer (torch.optim.Optimizer): the optimizer for the value network.
+            discriminator_optimizer (torch.optim.Optimizer): the optimizer for the discriminator.
+            gamma (float): the discount factor.
+            f (str): the function form of the f-divergence. Currently supports 'chi' and 'kl'.
             env (Environment): The environment to interact with. ``env`` is a
                 batched environment, which means that it runs multiple simulations
                 simultaneously. ``env`` only needs to be provided to the root
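
For context, this is how the updated constructor might be invoked after this change. A minimal sketch, assuming ALF's deferred-parameter optimizer wrappers (alf.optimizers.Adam) and illustrative specs; none of these values come from the repository:

import alf
from alf.tensor_specs import TensorSpec, BoundedTensorSpec

# Hypothetical configuration; the specs, learning rates, and the "kl" choice
# are illustrative assumptions, not values taken from this commit.
algorithm = SmodiceAlgorithm(
    observation_spec=TensorSpec((11, )),
    action_spec=BoundedTensorSpec((3, ), minimum=-1.0, maximum=1.0),
    actor_optimizer=alf.optimizers.Adam(lr=3e-4),
    value_optimizer=alf.optimizers.Adam(lr=3e-4),
    gamma=0.99,
    f="kl",  # newly accepted alongside the default "chi"
    name="SmodiceAlgorithm")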
@@ -242,12 +245,9 @@ def __init__(
         if discriminator_optimizer is not None and discriminator_net is not None:
             self.add_optimizer(discriminator_optimizer, [discriminator_net])
 
-        self._actor_optimizer = actor_optimizer
-        self._value_optimizer = value_optimizer
-        self._v_l2_reg = v_l2_reg
         self._gamma = gamma
         self._f = f
-        assert f == "chi", "only support chi form"
+        assert f in ["chi", "kl"], "only support chi or kl form"
 
         # f-divergence functions
         if self._f == 'chi':
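
The hunk ends just as the branch on self._f begins. For reference, a sketch of the two f-divergence generators and their convex conjugates that 'chi' and 'kl' conventionally denote in DICE-style methods; the function names are illustrative, and the exact forms used in this file may differ:

import torch

# chi^2 generator f(x) = 0.5 * (x - 1)^2 and its convex conjugate
# f*(y) = 0.5 * y^2 + y (the inner maximizer is x = y + 1).
def chi_f(x):
    return 0.5 * (x - 1)**2

def chi_f_star(y):
    return 0.5 * y**2 + y

# KL generator f(x) = x * log(x) and its convex conjugate f*(y) = exp(y - 1).
def kl_f(x):
    return x * torch.log(x)

def kl_f_star(y):
    return torch.exp(y - 1)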
@@ -327,14 +327,8 @@ def _discriminator_train_step(self, inputs: TimeStep, state, rollout_info,
         return LossInfo(loss=expert_loss, extra=SmoLossInfo(actor=expert_loss))
 
     def value_train_step(self, inputs: TimeStep, state, rollout_info):
-        # initial_v_values, e_v, result={}
         observation = inputs.observation
-
-        # extract initial observation from batch, or prepare a batch
         initial_observation = observation
-
-        # Shared network values
-        # mini_batch_length
         initial_v_values, _ = self._value_network(initial_observation)
 
         # mini-batch len
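
The rest of value_train_step is collapsed in this view. For orientation, a sketch of the SMODICE value objective that initial_v_values conventionally feeds into, L(V) = (1 - gamma) * E[V(s_0)] + E[f*(r + gamma * V(s') - V(s))], where r is the discriminator-derived reward; the helper below is an illustration based on the SMODICE paper, not code from this file:

import torch

def smodice_value_loss(initial_v, v, next_v, reward, gamma, f_star):
    # (1 - gamma) * E[V(s_0)]: keeps values at initial states small.
    initial_term = (1 - gamma) * initial_v.mean()
    # E[f*(r + gamma * V(s') - V(s))]: conjugate penalty on the Bellman residual.
    residual = reward + gamma * next_v - v
    return initial_term + f_star(residual).mean()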
