From a59045f51421078318a8102a370584a3fe93fc2c Mon Sep 17 00:00:00 2001
From: Ignat Georgiev
Date: Thu, 29 Jun 2023 11:47:35 -0400
Subject: [PATCH] update dflex and warp grad collect scripts

---
 scripts/cfg/env/hopper.yaml                    |  3 ++-
 scripts/grad_collect.py                        | 12 +++++++-----
 scripts/grad_collect_iter.py                   |  8 ++++----
 scripts/grad_collect_multistep.py              |  9 +++++----
 scripts/grad_collect_multistep_single_theta.py |  6 ++++--
 5 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/scripts/cfg/env/hopper.yaml b/scripts/cfg/env/hopper.yaml
index 24e0f316..f5d1ea27 100644
--- a/scripts/cfg/env/hopper.yaml
+++ b/scripts/cfg/env/hopper.yaml
@@ -1,10 +1,11 @@
 name: df_hopper
+env_name: HopperEnv
 
 config:
   _target_: shac.envs.HopperEnv
   render: ${general.render}
   device: ${general.device}
-  num_envs: ${resolve_default:512,${general.num_envs}}
+  num_envs: 512 # ${resolve_default:512,${general.num_envs}}
   seed: ${general.seed}
   episode_length: 1000
   no_grad: False
diff --git a/scripts/grad_collect.py b/scripts/grad_collect.py
index 8d15cb65..ae508375 100644
--- a/scripts/grad_collect.py
+++ b/scripts/grad_collect.py
@@ -14,7 +14,7 @@ def main(config: DictConfig):
     torch.random.manual_seed(config.general.seed)
 
     # create environment
-    env = instantiate(config.env)
+    env = instantiate(config.env.config)
 
     n = env.num_obs
     m = env.num_acts
@@ -31,14 +31,14 @@ def main(config: DictConfig):
     losses = []
     baseline = []
 
-    hh = np.arange(1, config.env.episode_length + 1, h_step)
+    hh = np.arange(1, config.env.config.episode_length + 1, h_step)
     for h in tqdm(hh):
         env.clear_grad()
         env.reset()
 
         ww = w.clone()
         ww.requires_grad_(True)
-        loss = torch.zeros(config.env.num_envs).to(device)
+        loss = torch.zeros(config.env.config.num_envs).to(device)
 
         # apply first noisy action
         obs, rew, done, info = env.step(ww)
@@ -65,8 +65,10 @@ def main(config: DictConfig):
         zobg = 1 / std**2 * (loss.unsqueeze(1) - loss[0]) * ww
         zobgs.append(zobg.detach().cpu().numpy())
 
-    filename = "{:}_grads_{:}".format(env.__class__.__name__, config.env.episode_length)
-    if "warp" in config.env._target_:
+    filename = "{:}_grads_{:}".format(
+        env.__class__.__name__, config.env.config.episode_length
+    )
+    if "warp" in config.env.config._target_:
         filename = "Warp" + filename
     filename = f"outputs/grads/{filename}"
     if hasattr(env, "start_state"):
diff --git a/scripts/grad_collect_iter.py b/scripts/grad_collect_iter.py
index 2ea834f1..b808e934 100644
--- a/scripts/grad_collect_iter.py
+++ b/scripts/grad_collect_iter.py
@@ -14,7 +14,7 @@ def main(config: DictConfig):
     torch.random.manual_seed(config.general.seed)
 
-    env = instantiate(config.env)
+    env = instantiate(config.env.config)
 
     n = env.num_obs
     m = env.num_acts
 
@@ -36,7 +36,7 @@ def main(config: DictConfig):
 
         ww = w.clone()
         ww.requires_grad_(True)
-        loss = torch.zeros(config.env.num_envs).to(device)
+        loss = torch.zeros(config.env.config.num_envs).to(device)
 
         # apply first noisy action
         obs, rew, done, info = env.step(ww)
@@ -62,9 +62,9 @@ def main(config: DictConfig):
         zobgs.append(zobg.detach().cpu().numpy())
 
     filename = "{:}_grads2_{:}".format(
-        env.__class__.__name__, config.env.episode_length
+        env.__class__.__name__, config.env.config.episode_length
     )
-    if "warp" in config.env._target_:
+    if "warp" in config.env.config._target_:
         filename = "Warp" + filename
     filename = f"outputs/grads/{filename}"
     if hasattr(env, "start_state"):
diff --git a/scripts/grad_collect_multistep.py b/scripts/grad_collect_multistep.py
index 88fef200..4bc10da3 100644
--- a/scripts/grad_collect_multistep.py
+++ b/scripts/grad_collect_multistep.py
@@ -7,7 +7,8 @@ import torch
 from tqdm import tqdm
 from torchviz import make_dot
 
-from shac.envs import DFlexEnv, WarpEnv
+from shac.envs import DFlexEnv
+from warp.envs import WarpEnv
 
 
 @hydra.main(version_base="1.2", config_path="cfg", config_name="config.yaml")
@@ -68,7 +69,7 @@ def policy(obs):
             (dpi,) = torch.autograd.grad(policy(obs.detach()).sum(), th)
             dpis.append(dpi)
             action = policy(obs) + w[t]
-            obs, rew, done, info = env.step(action)
+            obs, rew, term, trunc, info = env.step(action)
             loss += rew
             # NOTE: commented out code below is for the debugging of more efficient grad computation
             # make_dot(loss.sum(), show_attrs=True, show_saved=True).render("correct_graph")
@@ -107,9 +108,9 @@ def policy(obs):
 
     # Save data
     filename = "{:}_grads_ms_{:}".format(
-        env.__class__.__name__, config.env.episode_length
+        env.__class__.__name__, config.env.config.episode_length
     )
-    if "warp" in config.env._target_:
+    if "warp" in config.env.config._target_:
         filename = "Warp" + filename
     filename = f"outputs/grads/{filename}"
     if hasattr(env, "start_state"):
diff --git a/scripts/grad_collect_multistep_single_theta.py b/scripts/grad_collect_multistep_single_theta.py
index 0f9bab75..043af83d 100644
--- a/scripts/grad_collect_multistep_single_theta.py
+++ b/scripts/grad_collect_multistep_single_theta.py
@@ -17,7 +17,7 @@ def main(config: DictConfig):
     torch.random.manual_seed(config.general.seed)
 
     # create environment
-    env: Union[DFlexEnv, WarpEnv] = instantiate(config.env)
+    env: Union[DFlexEnv, WarpEnv] = instantiate(config.env.config)
 
     n = env.num_obs
     m = env.num_acts
@@ -103,7 +103,9 @@ def policy(obs):
         zobgs_no_grad.append(zobg_no_grad.detach().cpu().numpy())
 
     np.savez(
-        "{:}_grads_ms_{:}".format(env.__class__.__name__, config.env.episode_length),
+        "{:}_grads_ms_{:}".format(
+            env.__class__.__name__, config.env.config.episode_length
+        ),
         zobgs=zobgs,
         zobgs_no_grad=zobgs_no_grad,
         fobgs=fobgs,
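
Note on the config layout this patch migrates to: constructor arguments now live
under `config:` in each env YAML (with the Hydra `_target_`), while metadata such
as `name` and the new `env_name` stay at the top level. That is why every script
switches from `config.env` to `config.env.config` when instantiating and when
reading `episode_length` / `num_envs`. Below is a minimal, runnable sketch of the
pattern; `DummyEnv` is a hypothetical stand-in for `shac.envs.HopperEnv`, not
part of this repo:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    class DummyEnv:
        # Hypothetical stand-in for shac.envs.HopperEnv.
        def __init__(self, num_envs, episode_length):
            self.num_envs = num_envs
            self.episode_length = episode_length

    cfg = OmegaConf.create(
        {
            "env": {
                "name": "df_hopper",
                "env_name": "HopperEnv",
                "config": {
                    "_target_": "__main__.DummyEnv",
                    "num_envs": 512,
                    "episode_length": 1000,
                },
            }
        }
    )

    # instantiate() resolves _target_ and passes the remaining keys as kwargs;
    # the metadata keys (name, env_name) never reach the constructor.
    env = instantiate(cfg.env.config)
    print(env.num_envs, env.episode_length)  # 512 1000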
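
Note on the step API: grad_collect_multistep.py now unpacks five values from
env.step (the terminated/truncated convention), while grad_collect.py and
grad_collect_iter.py keep the four-value (obs, rew, done, info) form. A sketch of
a hypothetical helper, step_compat (not part of this patch), for code that has to
run against both conventions:

    def step_compat(env, action):
        """Normalize env.step() to (obs, rew, done, info) across both APIs."""
        out = env.step(action)
        if len(out) == 5:
            # five-value API: done is split into terminated/truncated
            obs, rew, term, trunc, info = out
            done = term | trunc  # `|` works for Python bools and torch bool tensors
        else:
            obs, rew, done, info = out
        return obs, rew, done, info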