Skip to content

Commit

Permalink
more cylinder testing
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Nov 12, 2024
1 parent 5098bc3 commit e721e9b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 18 deletions.
12 changes: 6 additions & 6 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,22 +837,22 @@ def objective(*args):
def test_autograd_polyslab_cylinder(use_emulated_run, monitor_key):
"""Test an objective function through tidy3d autograd."""

t = 1.0
t0 = 1.0
axis = 0

num_pts = 89
num_pts = 819

monitor, postprocess = make_monitors()[monitor_key]

def make_cylinder(radius, x0, y0):
def make_cylinder(radius, x0, y0, t):
return td.Cylinder(
center=td.Cylinder.unpop_axis(0.0, (x0, y0), axis=axis),
radius=radius,
length=t,
axis=axis,
) # .to_polyslab(num_pts)
).to_polyslab(num_pts)

def make_polyslab(radius, x0, y0):
def make_polyslab(radius, x0, y0, t):
phis = anp.linspace(0, 2 * np.pi, num_pts + 1)[:-1]

xs = radius * anp.cos(phis) + x0
Expand All @@ -872,7 +872,7 @@ def make_sim(params, geo_maker):

return SIM_BASE.updated_copy(structures=[structure], monitors=[monitor])

p0 = [1.0, 0.0, 0.0]
p0 = [1.0, 0.0, 0.0, t0]

def objective_polyslab(params):
"""Objective function."""
Expand Down
33 changes: 22 additions & 11 deletions tidy3d/components/geometry/polyslab.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,28 +1428,39 @@ def compute_derivative_slab_bounds(self, derivative_info: DerivativeInfo) -> Tra
normals_min = self.unpop_axis_vect(-ones, np.stack((zeros, zeros), axis=-1))
normals_max = self.unpop_axis_vect(+ones, np.stack((zeros, zeros), axis=-1))
perps1 = self.unpop_axis_vect(zeros, np.stack((ones, zeros), axis=-1))
perps2 = self.unpop_axis_vect(zeros, np.stack((ones, zeros), axis=-1))
perps2 = self.unpop_axis_vect(zeros, np.stack((zeros, ones), axis=-1))

# compute inside
xs, ys, _ = self.unpop_axis_vect(
0 * r1_centers, np.stack((r1_centers, r2_centers), axis=-1)
rr1_max, rr2_max, axx_max = np.meshgrid(r1_centers, r2_centers, ax_max)
rr1_min, rr2_min, axx_min = np.meshgrid(r1_centers, r2_centers, ax_min)

xx_max, yy_max, zz_max = self.unpop_axis_vect(
axx_max.flatten(),
np.stack((rr1_max.flatten(), rr2_max.flatten()), axis=-1),
).T

xx_min, yy_min, zz_min = self.unpop_axis_vect(
axx_min.flatten(),
np.stack((rr1_min.flatten(), rr2_min.flatten()), axis=-1),
).T
xx, yy, zz = np.meshgrid(xs, ys, self.center_axis)
inside = self.inside(xx, yy, zz).squeeze().flatten()
areas *= inside

inside_min = self.inside(xx_min, yy_min, zz_min).squeeze().flatten()
inside_max = self.inside(xx_max, yy_max, zz_max).squeeze().flatten()

areas_max = areas * inside_max
areas_min = areas * inside_min

# compute DerivativeSurfaceMesh for each top and bottom.
surface_mesh_min = DerivativeSurfaceMesh(
centers=edge_centers_xyz_min,
areas=areas,
areas=areas_min,
normals=normals_min,
perps1=perps1,
perps2=perps2,
)

surface_mesh_max = DerivativeSurfaceMesh(
centers=edge_centers_xyz_max,
areas=areas,
areas=areas_max,
normals=normals_max,
perps1=perps1,
perps2=perps2,
Expand All @@ -1458,8 +1469,8 @@ def compute_derivative_slab_bounds(self, derivative_info: DerivativeInfo) -> Tra
grads_min = derivative_info.grad_surfaces(surface_mesh=surface_mesh_min)
grads_max = derivative_info.grad_surfaces(surface_mesh=surface_mesh_max)

vjp_min = np.sum(grads_min).item()
vjp_max = np.sum(grads_max).item()
vjp_min = np.real(np.sum(grads_min).item())
vjp_max = np.real(np.sum(grads_max).item())

return [vjp_min, vjp_max]

Expand Down
2 changes: 1 addition & 1 deletion tidy3d/components/geometry/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldM
if path == ("length",):
vjps[path] = vjp_top - vjp_bot

if path == ("radius",):
elif path == ("radius",):
vjps[path] = vjp_xs + vjp_ys

elif "center" in path:
Expand Down

0 comments on commit e721e9b

Please sign in to comment.