Skip to content

Commit

Permalink
update dflex/warp jac tests
Browse files Browse the repository at this point in the history
  • Loading branch information
imgeorgiev committed Jun 29, 2023
1 parent b42f92c commit f90f7a8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
20 changes: 17 additions & 3 deletions examples/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ def f(inputs):
print("took {:.2f}".format(total_time))
print("jacobian shape", jac.shape)

# np.save("jac", jac.detach().cpu().numpy())
directory = "outputs"
if not os.path.exists(directory):
os.makedirs(directory)

filename = "jacs_{:}".format(args.env)
filename = f"{directory}/{filename}"
print("Saving to", filename)
np.save(filename, jac.detach().cpu().numpy())

for b in range(len(jac)):
for i in range(jac.shape[1]):
Expand Down Expand Up @@ -157,7 +164,14 @@ def example_jac2(args):
print("took {:.2f}".format(total_time))
print("jacobian shape", jac.shape)

# np.save("jac", jac.detach().cpu().numpy())
directory = "outputs"
if not os.path.exists(directory):
os.makedirs(directory)

filename = "jacs2_{:}".format(args.env)
filename = f"{directory}/{filename}"
print("Saving to", filename)
np.save(filename, jac.detach().cpu().numpy())

for b in range(len(jac)):
for i in range(jac.shape[1]):
Expand Down Expand Up @@ -238,7 +252,7 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="CartPoleSwingUpEnv")
parser.add_argument("--env", type=str, default="HopperEnv")
parser.add_argument("--num-envs", type=int, default=1)
parser.add_argument("--test", default=False, action="store_true")

Expand Down
7 changes: 5 additions & 2 deletions examples/test_jacobian_warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.


## BROKEN

import sys, os

project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
Expand All @@ -29,7 +32,7 @@
from warp.envs import HopperEnv


def test_jac(args, num_steps):
def test_jac(args):
seeding()

# env_fn = getattr(envs, args.env)
Expand Down Expand Up @@ -128,7 +131,7 @@ def check_grad(fn, inputs, eps=1e-6, atol=1e-4, rtol=1e-6):


def main(args):
test_jac(args, 1)
test_jac(args)


if __name__ == "__main__":
Expand Down

0 comments on commit f90f7a8

Please sign in to comment.