Skip to content

Commit

Permalink
update dflex and warp grad collect scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
imgeorgiev committed Jun 29, 2023
1 parent f90f7a8 commit a59045f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 16 deletions.
3 changes: 2 additions & 1 deletion scripts/cfg/env/hopper.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name: df_hopper
env_name: HopperEnv

config:
_target_: shac.envs.HopperEnv
render: ${general.render}
device: ${general.device}
num_envs: ${resolve_default:512,${general.num_envs}}
num_envs: 512 # ${resolve_default:512,${general.num_envs}}
seed: ${general.seed}
episode_length: 1000
no_grad: False
Expand Down
12 changes: 7 additions & 5 deletions scripts/grad_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def main(config: DictConfig):
torch.random.manual_seed(config.general.seed)

# create environment
env = instantiate(config.env)
env = instantiate(config.env.config)

n = env.num_obs
m = env.num_acts
Expand All @@ -31,14 +31,14 @@ def main(config: DictConfig):
losses = []
baseline = []

hh = np.arange(1, config.env.episode_length + 1, h_step)
hh = np.arange(1, config.env.config.episode_length + 1, h_step)
for h in tqdm(hh):
env.clear_grad()
env.reset()

ww = w.clone()
ww.requires_grad_(True)
loss = torch.zeros(config.env.num_envs).to(device)
loss = torch.zeros(config.env.config.num_envs).to(device)

# apply first noisy action
obs, rew, done, info = env.step(ww)
Expand All @@ -65,8 +65,10 @@ def main(config: DictConfig):
zobg = 1 / std**2 * (loss.unsqueeze(1) - loss[0]) * ww
zobgs.append(zobg.detach().cpu().numpy())

filename = "{:}_grads_{:}".format(env.__class__.__name__, config.env.episode_length)
if "warp" in config.env._target_:
filename = "{:}_grads_{:}".format(
env.__class__.__name__, config.env.config.episode_length
)
if "warp" in config.env.config._target_:
filename = "Warp" + filename
filename = f"outputs/grads/{filename}"
if hasattr(env, "start_state"):
Expand Down
8 changes: 4 additions & 4 deletions scripts/grad_collect_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def main(config: DictConfig):

torch.random.manual_seed(config.general.seed)

env = instantiate(config.env)
env = instantiate(config.env.config)

n = env.num_obs
m = env.num_acts
Expand All @@ -36,7 +36,7 @@ def main(config: DictConfig):

ww = w.clone()
ww.requires_grad_(True)
loss = torch.zeros(config.env.num_envs).to(device)
loss = torch.zeros(config.env.config.num_envs).to(device)

# apply first noisy action
obs, rew, done, info = env.step(ww)
Expand All @@ -62,9 +62,9 @@ def main(config: DictConfig):
zobgs.append(zobg.detach().cpu().numpy())

filename = "{:}_grads2_{:}".format(
env.__class__.__name__, config.env.episode_length
env.__class__.__name__, config.env.config.episode_length
)
if "warp" in config.env._target_:
if "warp" in config.env.config._target_:
filename = "Warp" + filename
filename = f"outputs/grads/{filename}"
if hasattr(env, "start_state"):
Expand Down
9 changes: 5 additions & 4 deletions scripts/grad_collect_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import torch
from tqdm import tqdm
from torchviz import make_dot
from shac.envs import DFlexEnv, WarpEnv
from shac.envs import DFlexEnv
from warp.envs import WarpEnv


@hydra.main(version_base="1.2", config_path="cfg", config_name="config.yaml")
Expand Down Expand Up @@ -68,7 +69,7 @@ def policy(obs):
(dpi,) = torch.autograd.grad(policy(obs.detach()).sum(), th)
dpis.append(dpi)
action = policy(obs) + w[t]
obs, rew, done, info = env.step(action)
obs, rew, term, trunc, info = env.step(action)
loss += rew
# NOTE: commented out code below is for the debugging of more efficient grad computation
# make_dot(loss.sum(), show_attrs=True, show_saved=True).render("correct_graph")
Expand Down Expand Up @@ -107,9 +108,9 @@ def policy(obs):

# Save data
filename = "{:}_grads_ms_{:}".format(
env.__class__.__name__, config.env.episode_length
env.__class__.__name__, config.env.config.episode_length
)
if "warp" in config.env._target_:
if "warp" in config.env.config._target_:
filename = "Warp" + filename
filename = f"outputs/grads/{filename}"
if hasattr(env, "start_state"):
Expand Down
6 changes: 4 additions & 2 deletions scripts/grad_collect_multistep_single_theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def main(config: DictConfig):
torch.random.manual_seed(config.general.seed)

# create environment
env: Union[DFlexEnv, WarpEnv] = instantiate(config.env)
env: Union[DFlexEnv, WarpEnv] = instantiate(config.env.config)

n = env.num_obs
m = env.num_acts
Expand Down Expand Up @@ -103,7 +103,9 @@ def policy(obs):
zobgs_no_grad.append(zobg_no_grad.detach().cpu().numpy())

np.savez(
"{:}_grads_ms_{:}".format(env.__class__.__name__, config.env.episode_length),
"{:}_grads_ms_{:}".format(
env.__class__.__name__, config.env.config.episode_length
),
zobgs=zobgs,
zobgs_no_grad=zobgs_no_grad,
fobgs=fobgs,
Expand Down

0 comments on commit a59045f

Please sign in to comment.