Commit

autocast support for a2c and ppo
realiti4 committed Aug 10, 2023
1 parent 41fa735 commit c5ce6bb
Showing 5 changed files with 83 additions and 57 deletions.
7 changes: 4 additions & 3 deletions tianshou/data/collector.py
@@ -274,9 +274,10 @@ def collect(
                 self.data.update(act=act_sample)
             else:
                 if no_grad:
-                    with torch.no_grad():  # faster than retain_grad version
-                        # self.data.obs will be used by agent to get result
-                        result = self.policy(self.data, last_state)
+                    with torch.autocast("cuda", enabled=self.policy.use_autocast):
+                        with torch.no_grad():  # faster than retain_grad version
+                            # self.data.obs will be used by agent to get result
+                            result = self.policy(self.data, last_state)
                 else:
                     result = self.policy(self.data, last_state)
             # update state / act / policy into self.data
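
The rollout forward pass above now runs under torch.autocast as well, so data collection uses mixed precision whenever the policy enables it, not just the learning step. Below is a minimal sketch of the same pattern outside Tianshou; policy_net and obs are placeholder names, and autocast on the "cuda" device type needs a GPU (with enabled=False it is a no-op).

    import torch

    # Sketch of the collection-time pattern above (assumed names, not Tianshou code).
    policy_net = torch.nn.Linear(4, 2).cuda()
    obs = torch.randn(1, 4, device="cuda")

    with torch.autocast("cuda", enabled=True):  # no-op if enabled=False
        with torch.no_grad():                   # no graph needed while collecting
            logits = policy_net(obs)            # float16 where autocast allows it
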
5 changes: 5 additions & 0 deletions tianshou/policy/base.py
@@ -7,6 +7,7 @@
 from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete
 from numba import njit
 from torch import nn
+from torch.cuda.amp import GradScaler

 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
 from tianshou.utils import MultipleLRSchedulers
@@ -67,6 +68,7 @@ def __init__(
         action_bound_method: str = "",
         lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
                                      MultipleLRSchedulers]] = None,
+        use_autocast: bool = False,
     ) -> None:
         super().__init__()
         self.observation_space = observation_space
@@ -85,6 +87,9 @@ def __init__(
         self.lr_scheduler = lr_scheduler
         self._compile()

+        self.scaler = GradScaler(growth_interval=1000, enabled=use_autocast)
+        self.use_autocast = use_autocast
+
     def set_agent_id(self, agent_id: int) -> None:
         """Set self.agent_id = agent_id, for MARL."""
         self.agent_id = agent_id
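
BasePolicy now owns a GradScaler whose enabled flag mirrors use_autocast, which keeps the whole feature opt-in: a disabled scaler hands the loss back unchanged and its step()/update() reduce to a plain optimizer step, so the default fp32 path is untouched. growth_interval=1000 halves PyTorch's default of 2000, meaning the loss scale is re-doubled after 1000 overflow-free steps instead of 2000; the commit does not say why, so read that as a tuning choice. A CPU-safe sketch of the disabled path:

    import torch
    from torch.cuda.amp import GradScaler

    # Sketch: with enabled=False the scaler is a pass-through (no GPU required).
    model = torch.nn.Linear(2, 1)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler(growth_interval=1000, enabled=False)

    loss = model(torch.randn(4, 2)).mean()
    scaler.scale(loss).backward()  # same as loss.backward() when disabled
    scaler.step(optim)             # same as optim.step() when disabled
    scaler.update()                # no-op when disabled
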
46 changes: 28 additions & 18 deletions tianshou/policy/modelfree/a2c.py
@@ -84,10 +84,11 @@ def _compute_returns(
         self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
     ) -> Batch:
         v_s, v_s_ = [], []
-        with torch.no_grad():
-            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
-                v_s.append(self.critic(minibatch.obs))
-                v_s_.append(self.critic(minibatch.obs_next))
+        with torch.autocast('cuda', enabled=self.use_autocast):
+            with torch.no_grad():
+                for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
+                    v_s.append(self.critic(minibatch.obs).float())
+                    v_s_.append(self.critic(minibatch.obs_next).float())
         batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
         v_s = batch.v_s.cpu().numpy()
         v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
@@ -123,25 +124,34 @@ def learn(  # type: ignore
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for minibatch in batch.split(batch_size, merge_last=True):
-                # calculate loss for actor
-                dist = self(minibatch).dist
-                log_prob = dist.log_prob(minibatch.act)
-                log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
-                actor_loss = -(log_prob * minibatch.adv).mean()
-                # calculate loss for critic
-                value = self.critic(minibatch.obs).flatten()
-                vf_loss = F.mse_loss(minibatch.returns, value)
-                # calculate regularization and overall loss
-                ent_loss = dist.entropy().mean()
-                loss = actor_loss + self._weight_vf * vf_loss \
-                    - self._weight_ent * ent_loss
+                with torch.autocast('cuda', enabled=self.use_autocast):
+                    # calculate loss for actor
+                    dist = self(minibatch).dist
+                    log_prob = dist.log_prob(minibatch.act)
+                    log_prob = log_prob.reshape(len(minibatch.adv), -1).transpose(0, 1)
+                    actor_loss = -(log_prob * minibatch.adv).mean()
+                    # calculate loss for critic
+                    value = self.critic(minibatch.obs).flatten()
+                    vf_loss = F.mse_loss(minibatch.returns, value)
+                    # calculate regularization and overall loss
+                    ent_loss = dist.entropy().mean()
+                    loss = actor_loss + self._weight_vf * vf_loss \
+                        - self._weight_ent * ent_loss
+
                 self.optim.zero_grad()
-                loss.backward()
+
+                self.scaler.scale(loss).backward()
+                # loss.backward()
                 if self._grad_norm:  # clip large gradient
+                    self.scaler.unscale_(self.optim)
                     nn.utils.clip_grad_norm_(
                         self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
-                self.optim.step()
+
+                # self.optim.step()
+                self.scaler.step(self.optim)
+                self.scaler.update()
+
                 actor_losses.append(actor_loss.item())
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
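
The update now follows the standard PyTorch AMP recipe: scale the loss before backward so small fp16 gradients do not flush to zero, call unscale_ before clipping so the _grad_norm threshold applies to true gradient magnitudes, then let scaler.step skip the update when gradients overflowed and scaler.update adjust the scale factor. A self-contained sketch of that recipe, with placeholder model and data (requires CUDA):

    import torch
    from torch import nn
    from torch.cuda.amp import GradScaler

    # Standard AMP training step (generic PyTorch recipe, not Tianshou code).
    model = nn.Linear(8, 1).cuda()
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler(enabled=True)

    data = torch.randn(32, 8, device="cuda")
    target = torch.randn(32, 1, device="cuda")

    with torch.autocast("cuda", enabled=True):  # forward pass in mixed precision
        loss = nn.functional.mse_loss(model(data), target)

    optim.zero_grad()
    scaler.scale(loss).backward()                      # backward on the scaled loss
    scaler.unscale_(optim)                             # restore true gradient scale...
    nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # ...so the clip threshold is meaningful
    scaler.step(optim)                                 # skips the step on overflow
    scaler.update()                                    # grows/shrinks the loss scale
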
80 changes: 45 additions & 35 deletions tianshou/policy/modelfree/ppo.py
@@ -103,46 +103,56 @@ def learn(  # type: ignore
             if self._recompute_adv and step > 0:
                 batch = self._compute_returns(batch, self._buffer, self._indices)
             for minibatch in batch.split(batch_size, merge_last=True):
-                # calculate loss for actor
-                dist = self(minibatch).dist
-                if self._norm_adv:
-                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
-                    minibatch.adv = (minibatch.adv -
-                                     mean) / (std + self._eps)  # per-batch norm
-                ratio = (dist.log_prob(minibatch.act) -
-                         minibatch.logp_old).exp().float()
-                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
-                surr1 = ratio * minibatch.adv
-                surr2 = ratio.clamp(
-                    1.0 - self._eps_clip, 1.0 + self._eps_clip
-                ) * minibatch.adv
-                if self._dual_clip:
-                    clip1 = torch.min(surr1, surr2)
-                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
-                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
-                else:
-                    clip_loss = -torch.min(surr1, surr2).mean()
-                # calculate loss for critic
-                value = self.critic(minibatch.obs).flatten()
-                if self._value_clip:
-                    v_clip = minibatch.v_s + \
-                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
-                    vf1 = (minibatch.returns - value).pow(2)
-                    vf2 = (minibatch.returns - v_clip).pow(2)
-                    vf_loss = torch.max(vf1, vf2).mean()
-                else:
-                    vf_loss = (minibatch.returns - value).pow(2).mean()
-                # calculate regularization and overall loss
-                ent_loss = dist.entropy().mean()
-                loss = clip_loss + self._weight_vf * vf_loss \
-                    - self._weight_ent * ent_loss
+                with torch.autocast('cuda', enabled=self.use_autocast):
+                    # calculate loss for actor
+                    dist = self(minibatch).dist
+                    if self._norm_adv:
+                        mean, std = minibatch.adv.mean(), minibatch.adv.std()
+                        minibatch.adv = (minibatch.adv -
+                                         mean) / (std + self._eps)  # per-batch norm
+                    ratio = (dist.log_prob(minibatch.act) -
+                             minibatch.logp_old).exp().float()
+                    ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
+                    surr1 = ratio * minibatch.adv
+                    surr2 = ratio.clamp(
+                        1.0 - self._eps_clip, 1.0 + self._eps_clip
+                    ) * minibatch.adv
+                    if self._dual_clip:
+                        clip1 = torch.min(surr1, surr2)
+                        clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
+                        clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
+                    else:
+                        clip_loss = -torch.min(surr1, surr2).mean()
+                    # calculate loss for critic
+                    value = self.critic(minibatch.obs).flatten()
+                    if self._value_clip:
+                        v_clip = minibatch.v_s + \
+                            (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
+                        vf1 = (minibatch.returns - value).pow(2)
+                        vf2 = (minibatch.returns - v_clip).pow(2)
+                        vf_loss = torch.max(vf1, vf2).mean()
+                    else:
+                        vf_loss = (minibatch.returns - value).pow(2).mean()
+                    # calculate regularization and overall loss
+                    ent_loss = dist.entropy().mean()
+                    loss = clip_loss + self._weight_vf * vf_loss \
+                        - self._weight_ent * ent_loss
+
                 self.optim.zero_grad()
-                loss.backward()
+
+                self.scaler.scale(loss).backward()
+                # loss.backward()
+
                 if self._grad_norm:  # clip large gradient
+                    self.scaler.unscale_(self.optim)
                     nn.utils.clip_grad_norm_(
                         self._actor_critic.parameters(), max_norm=self._grad_norm
                     )
-                self.optim.step()
+
+                # self.optim.step()
+                self.scaler.step(self.optim)
+                self.scaler.update()
+
                 clip_losses.append(clip_loss.item())
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
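
Note how the .float() casts pair with these autocast regions: under autocast, matmul-heavy ops return float16 tensors, so the critic values in _compute_returns and the probability ratio are cast back to float32 before the return and advantage arithmetic. A small sketch of that dtype behavior (requires CUDA):

    import torch

    # Sketch: autocast changes output dtypes; an explicit .float() restores fp32.
    lin = torch.nn.Linear(4, 1).cuda()
    x = torch.randn(2, 4, device="cuda")

    with torch.autocast("cuda"):
        y = lin(x)

    print(y.dtype)          # torch.float16 (produced under autocast)
    print(y.float().dtype)  # torch.float32 after the explicit cast
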
2 changes: 1 addition & 1 deletion tianshou/trainer/base.py
@@ -230,7 +230,7 @@ def reset(self) -> None:
         elif self.test_collector is None:
             self.test_in_train = False

-        if self.test_collector is not None:
+        if self.test_collector is not None and self.epoch > 0:  # Don't test before training, don't waste time
             assert self.episode_per_test is not None
             assert not isinstance(self.test_collector, AsyncCollector)  # Issue 700
             self.test_collector.reset_stat()
