Skip to content

Commit

Permalink
fix warp_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
krishpop committed Jan 22, 2023
1 parent 26f9f28 commit b93ac36
Showing 1 changed file with 101 additions and 80 deletions.
181 changes: 101 additions & 80 deletions src/shac/utils/warp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

PI = wp.constant(np.pi)


@wp.kernel
def assign_kernel(
b: wp.array(dtype=float),
Expand Down Expand Up @@ -35,7 +36,7 @@ def assign_act_kernel(
):
tid = wp.tid()
a[2 * tid] = b[tid, 0]
a[2 * tid + 1] = 0.
a[2 * tid + 1] = 0.0


def float_assign_joint_act(a, b):
Expand All @@ -51,9 +52,9 @@ def float_assign_joint_act(a, b):

@wp.kernel
def assign_transform_kernel(
b: wp.array(dtype=wp.transform),
# outputs
a: wp.array(dtype=wp.transform),
b: wp.array(dtype=wp.transform),
# outputs
a: wp.array(dtype=wp.transform),
):
tid = wp.tid()
a[tid] = b[tid]
Expand All @@ -69,15 +70,17 @@ def transform_assign(a, b):
)
return a


@wp.kernel
def assign_spatial_kernel(
b: wp.array(dtype=wp.spatial_vector),
# outputs
a: wp.array(dtype=wp.spatial_vector),
b: wp.array(dtype=wp.spatial_vector),
# outputs
a: wp.array(dtype=wp.spatial_vector),
):
tid = wp.tid()
a[tid] = b[tid]


def spatial_assign(a, b):
wp.launch(
assign_spatial_kernel,
Expand All @@ -88,12 +91,11 @@ def spatial_assign(a, b):
)
return a


@wp.func
def compute_joint_q(
X_wp: wp.transform,
X_wc: wp.transform,
axis: wp.vec3,
rotation_count: float):
X_wp: wp.transform, X_wc: wp.transform, axis: wp.vec3, rotation_count: float
):

# child transform and moment arm
q_p = wp.transform_get_rotation(X_wp)
Expand All @@ -114,12 +116,12 @@ def compute_joint_q(

@wp.func
def compute_joint_qd(
X_wp: wp.transform,
X_wc: wp.transform,
w_p: wp.vec3,
w_c: wp.vec3,
axis: wp.vec3,
):
X_wp: wp.transform,
X_wc: wp.transform,
w_p: wp.vec3,
w_c: wp.vec3,
axis: wp.vec3,
):
axis_p = wp.transform_vector(X_wp, axis)
# angular error vel
w_err = w_c - w_p
Expand All @@ -129,20 +131,21 @@ def compute_joint_qd(

@wp.kernel
def get_joint_q(
body_q: wp.array(dtype=wp.transform),
joint_type: wp.array(dtype=int),
joint_parent: wp.array(dtype=int),
joint_X_p: wp.array(dtype=wp.transform),
joint_axis: wp.array(dtype=wp.vec3),
joint_rotation_count: float,
# outputs
joint_q: wp.array(dtype=float)):
body_q: wp.array(dtype=wp.transform),
joint_type: wp.array(dtype=int),
joint_parent: wp.array(dtype=int),
joint_X_p: wp.array(dtype=wp.transform),
joint_axis: wp.array(dtype=wp.vec3),
joint_rotation_count: float,
# outputs
joint_q: wp.array(dtype=float),
):

tid = wp.tid()
type = joint_type[tid]
axis = joint_axis[tid]

if type != wp.sim.JOINT_REVOLUTE_TIGHT and type != wp.sim.JOINT_REVOLUTE:
if type != wp.sim.JOINT_REVOLUTE:
return

c_child = tid
Expand All @@ -159,24 +162,22 @@ def get_joint_q(

@wp.kernel
def get_joint_qd(
body_q: wp.array(dtype=wp.transform),
body_qd: wp.array(dtype=wp.spatial_vector),
joint_qd_start: wp.array(dtype=int),
joint_type: wp.array(dtype=int),
joint_parent: wp.array(dtype=int),
joint_X_p: wp.array(dtype=wp.transform),
joint_axis: wp.array(dtype=wp.vec3),
# outputs
joint_qd: wp.array(dtype=float)):
body_q: wp.array(dtype=wp.transform),
body_qd: wp.array(dtype=wp.spatial_vector),
joint_qd_start: wp.array(dtype=int),
joint_type: wp.array(dtype=int),
joint_parent: wp.array(dtype=int),
joint_X_p: wp.array(dtype=wp.transform),
joint_axis: wp.array(dtype=wp.vec3),
# outputs
joint_qd: wp.array(dtype=float),
):

tid = wp.tid()
type = joint_type[tid]
qd_start = joint_qd_start[tid]
axis = joint_axis[tid]

if type != wp.sim.JOINT_REVOLUTE_TIGHT:
return

c_child = tid
c_parent = joint_parent[tid]
X_wp = joint_X_p[tid]
Expand All @@ -196,16 +197,18 @@ def get_joint_qd(

class IntegratorSimulate(torch.autograd.Function):
@staticmethod
def forward(ctx,
model,
state_in,
integrator,
dt,
substeps,
act,
body_q,
body_qd,
state_out,):
def forward(
ctx,
model,
state_in,
integrator,
dt,
substeps,
act,
body_q,
body_qd,
state_out,
):
ctx.tape = wp.Tape()
ctx.model = model
ctx.act = wp.from_torch(act)
Expand All @@ -227,7 +230,7 @@ def forward(ctx,
ctx.model.shape_materials.kd.requires_grad = True
ctx.model.shape_materials.kf.requires_grad = True
ctx.model.shape_materials.mu.requires_grad = True
ctx.model.shape_materials.restitution.requires_grad = True
ctx.model.shape_materials.restitution.requires_grad = True

with ctx.tape:
float_assign_joint_act(ctx.model.joint_act, ctx.act)
Expand All @@ -240,43 +243,52 @@ def forward(ctx,
state_in.clear_forces()
state_temp = model.state(requires_grad=True)
state_temp = integrator.simulate(
ctx.model, state_in, state_temp, dt / float(substeps),
requires_grad=True
ctx.model,
state_in,
state_temp,
dt / float(substeps),
requires_grad=True,
)
state_in = state_temp
state_in.clear_forces()
# updates joint_q joint_qd
ctx.state_out = integrator.simulate(ctx.model, state_in, state_out, dt / float(substeps), requires_grad=True)
ctx.state_out = integrator.simulate(
ctx.model, state_in, state_out, dt / float(substeps), requires_grad=True
)
# TODO: Check if calling collide after running substeps is correct
if ctx.model.ground:
wp.sim.collide(ctx.model, ctx.state_out)
# wp.sim.eval_ik(ctx.model, ctx.state_out, ctx.joint_q_end, ctx.joint_qd_end)

wp.launch(kernel=get_joint_q,
dim=model.joint_count,
device=model.device,
inputs=[
ctx.state_out.body_q,
model.joint_type,
model.joint_parent,
model.joint_X_p,
model.joint_axis,
0.
],
outputs=[ctx.joint_q_end])
wp.launch(kernel=get_joint_qd,
dim=model.joint_count,
device=model.device,
inputs=[
ctx.state_out.body_q,
ctx.state_out.body_qd,
model.joint_qd_start,
model.joint_type,
model.joint_parent,
model.joint_X_p,
model.joint_axis,
],
outputs=[ctx.joint_qd_end])
wp.launch(
kernel=get_joint_q,
dim=model.joint_count,
device=model.device,
inputs=[
ctx.state_out.body_q,
model.joint_type,
model.joint_parent,
model.joint_X_p,
model.joint_axis,
0.0,
],
outputs=[ctx.joint_q_end],
)
wp.launch(
kernel=get_joint_qd,
dim=model.joint_count,
device=model.device,
inputs=[
ctx.state_out.body_q,
ctx.state_out.body_qd,
model.joint_qd_start,
model.joint_type,
model.joint_parent,
model.joint_X_p,
model.joint_axis,
],
outputs=[ctx.joint_qd_end],
)
joint_q_end = wp.to_torch(ctx.joint_q_end)
joint_qd_end = wp.to_torch(ctx.joint_qd_end)
return (
Expand All @@ -302,16 +314,25 @@ def backward(ctx, adj_joint_q, adj_joint_qd, _a):

ctx.tape.zero()
# return adjoint w.r.t. inputs
return (None, None, None, None, None, joint_act_grad, body_q_grad, body_qd_grad, None)
return (
None,
None,
None,
None,
None,
joint_act_grad,
body_q_grad,
body_qd_grad,
None,
)


def check_grads(wp_struct):
for var in wp_struct.__dict__:
if isinstance(getattr(wp_struct, var), wp.array):
arr = getattr(wp_struct, var)
if arr.requires_grad:
assert (np.count_nonzero(arr.grad.numpy()) == 0), "var grad is non_zero"
assert np.count_nonzero(arr.grad.numpy()) == 0, "var grad is non_zero"
else:
if arr.dtype in [wp.vec3, wp.vec4, float, wp.float32]:
print(var)

0 comments on commit b93ac36

Please sign in to comment.