Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Aug 28, 2024
1 parent b211dc8 commit e70a93c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 129 deletions.
48 changes: 14 additions & 34 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
FWIDTH = FREQ0 / 10

# sim sizes
LZ = 7 * WVL
LZ = 7.0 * WVL

IS_3D = False

Expand Down Expand Up @@ -420,9 +420,11 @@ def make_structures(params: anp.ndarray) -> dict[str, td.Structure]:
def make_monitors() -> dict[str, tuple[td.Monitor, typing.Callable[[td.SimulationData], float]]]:
"""Make a dictionary of all the possible monitors in the simulation."""

X = 0.75

mode_mnt = td.ModeMonitor(
size=(2, 2, 0),
center=(0, 0, LZ / 2 - WVL),
center=(0, 0, +LZ / 2 - X * WVL),
mode_spec=td.ModeSpec(),
freqs=[FREQ0],
name="mode",
Expand All @@ -444,34 +446,18 @@ def diff_postprocess_fn(sim_data, mnt_data):

field_vol = td.FieldMonitor(
size=(1, 1, 0),
center=(0, 0, +LZ / 2 - WVL),
center=(0, 0, +LZ / 2 - X * WVL),
freqs=[FREQ0],
name="field_vol",
)

# def field_vol_postprocess_fn(sim_data, mnt_data):
# value = 0.0
# for _, val in mnt_data.field_components.items():
# for key in ('Hx', 'Hy', 'Hz'):
# val = mnt_data.field_components[key]
# # # if key == 'Ex':
# # value += abs(anp.sum(abs(val.values)))
# # else:
# # value += 0 * abs(anp.sum(abs(val.values)))

# value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
# return value

def field_vol_postprocess_fn(sim_data, mnt_data):
value = 0.0
# for _, val in mnt_data.field_components.items():
# value = value + abs(anp.sum(val.values))
# field components numerical is 3x higher
# intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values))
# value += intensity
# intensity numerical is 4.79x higher
for _, val in mnt_data.field_components.items():
value = value + abs(anp.sum(val.values))
intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values))
value += intensity
value += anp.sum(mnt_data.flux.values)
# flux is 18.4x lower
return value

field_point = td.FieldMonitor(
Expand All @@ -483,15 +469,9 @@ def field_vol_postprocess_fn(sim_data, mnt_data):

def field_point_postprocess_fn(sim_data, mnt_data):
value = 0.0
# for _, val in mnt_data.field_components.items():
for key in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
val = mnt_data.field_components[key]
if key == "Ey":
value += abs(anp.sum(abs(val.values)))
else:
value += 0 * abs(anp.sum(abs(val.values)))

# value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
for _, val in mnt_data.field_components.items():
value += abs(anp.sum(abs(val.values)))
value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
return value

return dict(
Expand Down Expand Up @@ -618,7 +598,7 @@ def test_polyslab_axis_ops(axis):


@pytest.mark.skipif(not RUN_NUMERICAL, reason="Numerical gradient tests runs through web API.")
@pytest.mark.parametrize("structure_key, monitor_key", (("medium", "mode"),))
@pytest.mark.parametrize("structure_key, monitor_key", (("medium", "field_vol"),))
def test_autograd_numerical(structure_key, monitor_key):
"""Test an objective function through tidy3d autograd."""

Expand All @@ -642,7 +622,7 @@ def objective(*args):
assert anp.all(grad != 0.0), "some gradients are 0"

# numerical gradients
delta = 1e-3
delta = 1e-2
sims_numerical = {}

params_num = np.zeros((N_PARAMS, N_PARAMS))
Expand Down
74 changes: 15 additions & 59 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ...log import log
from ..base import TYPE_TAG_STR, cached_property, skip_if_fields_missing
from ..base_sim.data.monitor_data import AbstractMonitorData
from ..geometry.base import Box
from ..grid.grid import Coords, Grid
from ..medium import Medium, MediumType
from ..monitor import (
Expand Down Expand Up @@ -1069,80 +1068,33 @@ def to_adjoint_field_sources(self, fwidth: float) -> List[CustomCurrentSource]:

sources = []

# Define source geometry based on coordinates in the data
data_mins = []
data_maxs = []
source_geo = self.monitor.geometry
freqs = self.monitor.freqs

def shift_value(coords) -> float:
"""How much to shift the geometry by along a dimension (only if > 1D)."""
return SHIFT_VALUE_ADJ_FLD_SRC if len(coords) > 1 else 0

for _, field_component in self.field_components.items():
coords = field_component.coords
data_mins.append({key: min(val) + shift_value(val) for key, val in coords.items()})
data_maxs.append({key: max(val) + shift_value(val) for key, val in coords.items()})

rmin = []
rmax = []
for dim in "xyz":
rmin.append(max(val[dim] for val in data_mins))
rmax.append(min(val[dim] for val in data_maxs))

source_geo = Box.from_bounds(rmin=rmin, rmax=rmax)

# Define source dataset
# Offset coordinates by source center since local coords are assumed in CustomCurrentSource

for freq0 in tuple(self.field_components.values())[0].coords["f"]:
for freq0 in freqs:
src_field_components = {}
for name, field_component in self.field_components.items():
# get the VJP values at frequency and apply adjoint phase
field_component = field_component.sel(f=freq0)
forward_amps = field_component.values

# values = -1j * forward_amps
# # rms_error = 0.0012
# # |grad| / |grad_num| = 0.6020

# values = 1 * forward_amps
# # rms_error = 2.0000
# # |grad| / |grad_num| = 0.2679

# values = +1j * forward_amps
# # rms_error = 2.0000
# # |grad| / |grad_num| = 0.6020

# values = -1 * forward_amps
# # rms_error = 0.0013
# # |grad| / |grad_num| = 0.2679

# values = -1j * np.conj(forward_amps)
# # rms_error = 2.0000
# # |grad| / |grad_num| = 0.0193

values = np.conj(forward_amps)
# rms_error = 0.0012
# |grad| / |grad_num| = 0.9772

# values = +1j * np.conj(forward_amps)
# # rms_error = 0.0013
# # |grad| / |grad_num| = 0.0193

# values = -1 * np.conj(forward_amps)
# # rms_error = 2.0000
# # |grad| / |grad_num| = 0.9772
values = -1j * field_component.values

# make source go backwards
if "H" in name:
values *= -1

# make coords that are shifted relative to geometry (0,0,0) = geometry.center
coords = dict(field_component.coords.copy())
for dim, key in enumerate("xyz"):
coords[key] = np.array(coords[key]) - source_geo.center[dim]
coords["f"] = np.array([freq0])
values = np.expand_dims(values, axis=-1)

# ignore zero components
if not np.all(values == 0):
src_field_components[name] = ScalarFieldDataArray(values, coords=coords)

# construct custom Current source
dataset = FieldDataset(**src_field_components)

custom_source = CustomCurrentSource(
center=source_geo.center,
size=source_geo.size,
Expand Down Expand Up @@ -1982,6 +1934,10 @@ def make_adjoint_sources(
) -> List[Union[CustomCurrentSource, PointDipole]]:
"""Converts a :class:`.FieldData` to a list of adjoint current or point sources."""

# avoids error in edge case where there are extraneous flux monitors not used in objective
if np.all(self.flux.values == 0.0):
return []

raise NotImplementedError(
"Could not formulate adjoint source for 'FluxMonitor' output. To compute derivatives "
"with respect to flux data, please use a 'FieldMonitor' and call '.flux' on the "
Expand Down
59 changes: 23 additions & 36 deletions tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
# default value for whether to do local gradient calculation (True) or server side (False)
LOCAL_GRADIENT = True

# if True, will plot the adjoint fields on the plane provided. used for debugging only
_INSPECT_ADJOINT_FIELDS = False
_INSPECT_ADJOINT_PLANE = td.Box(center=(0, 0, 0), size=(td.inf, td.inf, 0))


def is_valid_for_autograd(simulation: td.Simulation) -> bool:
"""Check whether a supplied simulation can use autograd run."""
Expand Down Expand Up @@ -731,63 +735,46 @@ def setup_adj(

td.log.info("Running custom vjp (adjoint) pipeline.")

# import pdb; pdb.set_trace()

# immediately filter out any data_vjps with all 0's in the data
data_fields_vjp_old = {
key: get_static(value) for key, value in data_fields_vjp.items() if not np.all(value == 0.0)
}
data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()}

# insert the raw VJP data into the .data of the original SimulationData
sim_data_vjp_old = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp_old)
sim_data_vjp = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp)

# make adjoint simulation from that SimulationData
data_vjp_paths_old = set(data_fields_vjp_old.keys())
data_vjp_paths = set(data_fields_vjp.keys())

num_monitors = len(sim_data_orig.simulation.monitors)
adjoint_monitors = sim_data_orig.simulation.with_adjoint_monitors(sim_fields_keys).monitors[
num_monitors:
]

sim_adj_old, adjoint_source_info_old = sim_data_vjp_old.make_adjoint_sim(
data_vjp_paths=data_vjp_paths_old, adjoint_monitors=adjoint_monitors
)

sim_adj, adjoint_source_info = sim_data_vjp.make_adjoint_sim(
data_vjp_paths=data_vjp_paths, adjoint_monitors=adjoint_monitors
)

# def get_current_amps(adjoint_source_info):
# vals = {}
# for key in ('Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz'):
# dataset = adjoint_source_info.sources[0].current_dataset
# if key in dataset.field_components:
# vals[key] = np.sum(abs(dataset.field_components[key]))
# return vals

# amps_old = get_current_amps(adjoint_source_info_old)
# amps_new = get_current_amps(adjoint_source_info)

# fld_mnt = td.FieldMonitor(
# size=(td.inf, td.inf, td.inf),
# freqs=adjoint_monitors[0].freqs,
# name='field',
# center=(0,0,0),
# )

# import tidy3d.web as web
# import matplotlib.pylab as plt
if _INSPECT_ADJOINT_FIELDS:
adj_fld_mnt = td.FieldMonitor(
center=_INSPECT_ADJOINT_PLANE.center,
size=_INSPECT_ADJOINT_PLANE.size,
freqs=adjoint_monitors[0].freqs,
name="adjoint_fields",
)

# sim_data_old = web.run(sim_adj_old.updated_copy(monitors=[fld_mnt]), task_name='old')
# sim_data_new = web.run(sim_adj.updated_copy(monitors=[fld_mnt]), task_name='old')
import matplotlib.pylab as plt

# sim_data_old.plot_field('field', 'E', 're')
# sim_data_new.plot_field('field', 'E', 're')
import tidy3d.web as web

# import pdb; pdb.set_trace()
sim_data_new = web.run(
sim_adj.updated_copy(monitors=[adj_fld_mnt]),
task_name="adjoint_field_viz",
verbose=False,
)
_, (ax1, ax2, ax3) = plt.subplots(1, 3, tight_layout=True, figsize=(10, 4))
sim_data_new.plot_field("adjoint_fields", "Ex", "re", ax=ax1)
sim_data_new.plot_field("adjoint_fields", "Ey", "re", ax=ax2)
sim_data_new.plot_field("adjoint_fields", "Ez", "re", ax=ax3)
plt.show()

td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.")

Expand Down

0 comments on commit e70a93c

Please sign in to comment.