Skip to content

Commit

Permalink
add cylinder differentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Aug 22, 2024
1 parent 0d3d462 commit c926dce
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added value_and_grad function to the autograd plugin, importable via `from tidy3d.plugins.autograd import value_and_grad`. Supports differentiating functions with auxiliary data (`value_and_grad(f, has_aux=True)`).
- `Simulation.num_computational_grid_points` property to examine the number of grid cells that compose the computational domain corresponding to the simulation. This can differ from `Simulation.num_cells` based on boundary conditions and symmetries.
- Support for `dilation` argument in `JaxPolySlab`.
- Support for autograd differentiation with respect to `Cylinder.radius` and `Cylinder.center` (for elements not along axis dimension).
- `Cylinder.to_polyslab(num_pts_radius, **kwargs)` to convert a cylinder into a discretized version represented by a `PolySlab`.

### Changed
- `PolySlab` now raises error when differentiating and dilation causes damage to the polygon.
Expand Down
16 changes: 15 additions & 1 deletion tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,17 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
custom_med_pole_res = td.CustomPoleResidue(eps_inf=eps_inf, poles=[(a1, c1), (a2, c2)])
custom_pole_res = td.Structure(geometry=box, medium=custom_med_pole_res)

radius = 1 + anp.abs(vector @ params)
cyl_center_y = vector @ params
cyl_center_z = -vector @ params
cylinder_geo = td.Cylinder(
radius=radius,
center=(0, cyl_center_y, cyl_center_z),
axis=0,
length=LX / 2 if IS_3D else td.inf,
)
cylinder = td.Structure(geometry=cylinder_geo, medium=med)

return dict(
medium=medium,
center_list=center_list,
Expand All @@ -382,6 +393,7 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
geo_group=geo_group,
pole_res=pole_res,
custom_pole_res=custom_pole_res,
cylinder=cylinder,
)


Expand Down Expand Up @@ -471,6 +483,7 @@ def plot_sim(sim: td.Simulation, plot_eps: bool = False) -> None:
"geo_group",
"pole_res",
"custom_pole_res",
"cylinder",
)
monitor_keys_ = ("mode", "diff", "field_vol", "field_point")

Expand Down Expand Up @@ -533,7 +546,8 @@ def postprocess(data: td.SimulationData) -> float:
return dict(sim=make_sim, postprocess=postprocess)


@pytest.mark.parametrize("structure_key, monitor_key", args)
# @pytest.mark.parametrize("structure_key, monitor_key", args)
@pytest.mark.parametrize("structure_key, monitor_key", (("cylinder", "mode"),))
def test_autograd_objective(use_emulated_run, structure_key, monitor_key):
"""Test an objective function through tidy3d autograd."""

Expand Down
4 changes: 4 additions & 0 deletions tests/test_components/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ def test_slanted_cylinder_infinite_length_validate():
)


def test_cylinder_to_polyslab():
ps = CYLINDER.to_polyslab(num_pts_radius=10, sidewall_angle=0.02)


def test_box_from_bounds():
b = td.Box.from_bounds(rmin=(-td.inf, 0, 0), rmax=(td.inf, 0, 0))
assert b.center[0] == 0.0
Expand Down
108 changes: 107 additions & 1 deletion tidy3d/components/geometry/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,30 @@
from math import isclose
from typing import List

import autograd.numpy as anp
import numpy as np
import pydantic.v1 as pydantic
import shapely

from ...constants import LARGE_NUMBER, MICROMETER
from ...exceptions import SetupError, ValidationError
from ...packaging import verify_packages_import
from ..autograd import AutogradFieldMap, TracedSize1D
from ..autograd.derivative_utils import DerivativeInfo
from ..base import cached_property, skip_if_fields_missing
from ..types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely, Tuple
from . import base
from .polyslab import PolySlab

# for sampling conical frustum in visualization
_N_SAMPLE_CURVE_SHAPELY = 40

# for shapely circular shapes discretization in visualization
_N_SHAPELY_QUAD_SEGS = 200

# number of points to discretize polyslab used to represent cylinder in `Cylinder.to_polyslab()`
N_PTS_CYLINDER_POLYSLAB = 101


class Sphere(base.Centered, base.Circular):
"""Spherical geometry.
Expand Down Expand Up @@ -185,7 +192,7 @@ class Cylinder(base.Centered, base.Circular, base.Planar):
"""

# Provide more explanations on where radius is defined
radius: pydantic.NonNegativeFloat = pydantic.Field(
radius: TracedSize1D = pydantic.Field(
...,
title="Radius",
description="Radius of geometry at the ``reference_plane``.",
Expand Down Expand Up @@ -215,6 +222,105 @@ def _only_middle_for_infinite_length_slanted_cylinder(cls, val, values):
)
return val

def to_polyslab(self, num_pts_radius: int = N_PTS_CYLINDER_POLYSLAB, **kwargs) -> PolySlab:
"""Convert instance of ``Cylinder`` into a discretized version using ``PolySlab``.
Parameters
----------
num_pts_radius : int = 100
Number of points in the radius of the discretized polyslab.
**kwargs:
Extra keyword arguments passed to ``PolySlab()``, such as ``sidewall_angle``.
Returns
-------
PolySlab
Extruded polygon representing a discretized version of the cylinder.
"""

center_axis = self.center_axis
length_axis = self.length_axis
slab_bounds = (center_axis - length_axis / 2.0, center_axis + length_axis / 2.0)

if num_pts_radius < 3:
raise ValueError("'PolySlab' from 'Cylinder' must have 3 or more radius points.")

_, (x0, y0) = self.pop_axis(self.center, axis=self.axis)

xs_, ys_ = self._points_unit_circle(num_pts_radius=num_pts_radius)

xs = x0 + self.radius * xs_
ys = y0 + self.radius * ys_

vertices = anp.stack((xs, ys), axis=-1)

return PolySlab(
vertices=vertices,
axis=self.axis,
slab_bounds=slab_bounds,
**kwargs,
)

def _points_unit_circle(self, num_pts_radius: int = N_PTS_CYLINDER_POLYSLAB) -> np.ndarray:
"""Set of x and y points for the unit circle when discretizing cylinder as a polyslab."""
angles = np.linspace(0, 1, num_pts_radius + 1)[:-1]
xs = np.cos(angles)
ys = np.sin(angles)
return np.stack((xs, ys), axis=0)

def compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
"""Compute the adjoint derivatives for this object."""

polyslab = self.to_polyslab()

derivative_info_polyslab = derivative_info.updated_copy(paths=[("vertices",)])
vjps_polyslab = polyslab.compute_derivatives(derivative_info_polyslab)

vjps_vertices_xs, vjps_vertices_ys = vjps_polyslab[("vertices",)].T

vjps = {}
for path in derivative_info.paths:
if path == ("radius",):
xs_, ys_ = self._points_unit_circle()

vjp_xs = np.sum(xs_ * vjps_vertices_xs)
vjp_ys = np.sum(ys_ * vjps_vertices_ys)

vjps[path] = vjp_xs + vjp_ys

elif "center" in path:
_, center_index = path
if center_index == self.axis:
raise NotImplementedError(
"Can't current differentiate Cylinder with respect to its 'center' along "
"the axis. If you would like this feature added, please feel free to raise "
"an issue on the tidy3d front end repository."
)

_, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis)
if center_index == index_x:
vjps[path] = np.sum(vjp_xs)
elif center_index == index_y:
vjps[path] = np.sum(vjp_ys)
else:
raise ValueError(
"Something unexpected happened. Was asked to differentiate "
f"with respect to 'Cylinder.center[{center_index}]', but this was not "
"detected as being one of the parallel axis with "
f"'Cylinder.axis' of '{self.axis}'. If you received this error, please raise "
"an issue on the tidy3d front end repository with details about how you "
"defined your 'Cylinder' in the objective function."
)

else:
raise NotImplementedError(
f"Differentiation with respect to 'Cylinder' '{path}' field not supported. "
"If you would like this feature added, please feel free to raise "
"an issue on the tidy3d front end repository."
)

return vjps

@property
def center_axis(self):
"""Gets the position of the center of the geometry in the out of plane dimension."""
Expand Down

0 comments on commit c926dce

Please sign in to comment.