Skip to content

Commit

Permalink
fix FieldData gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Aug 27, 2024
1 parent 22219ee commit 505adb7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 11 deletions.
37 changes: 28 additions & 9 deletions tests/test_components/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
TEST_POLYSLAB_SPEED = False

# whether to run numerical gradient tests, off by default because it runs real simulations
RUN_NUMERICAL = False
RUN_NUMERICAL = True

TEST_MODES = ("pipeline", "adjoint", "speed")
TEST_MODE = "speed" if TEST_POLYSLAB_SPEED else "pipeline"
Expand Down Expand Up @@ -449,13 +449,26 @@ def diff_postprocess_fn(sim_data, mnt_data):
name="field_vol",
)

# def field_vol_postprocess_fn(sim_data, mnt_data):
# value = 0.0
# for _, val in mnt_data.field_components.items():
# for key in ('Hx', 'Hy', 'Hz'):
# val = mnt_data.field_components[key]
# # # if key == 'Ex':
# # value += abs(anp.sum(abs(val.values)))
# # else:
# # value += 0 * abs(anp.sum(abs(val.values)))

# value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
# return value

def field_vol_postprocess_fn(sim_data, mnt_data):
value = 0.0
for _, val in mnt_data.field_components.items():
value = value + abs(anp.sum(val.values))
# for _, val in mnt_data.field_components.items():
# value = value + abs(anp.sum(val.values))
# field components numerical is 3x higher
intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values))
value += intensity
# intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values))
# value += intensity
# intensity numerical is 4.79x higher
value += anp.sum(mnt_data.flux.values)
# flux is 18.4x lower
Expand All @@ -470,9 +483,15 @@ def field_vol_postprocess_fn(sim_data, mnt_data):

def field_point_postprocess_fn(sim_data, mnt_data):
value = 0.0
for _, val in mnt_data.field_components.items():
value += abs(anp.sum(val.values))
value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
# for _, val in mnt_data.field_components.items():
for key in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
val = mnt_data.field_components[key]
if key == "Ey":
value += abs(anp.sum(abs(val.values)))
else:
value += 0 * abs(anp.sum(abs(val.values)))

# value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)
return value

return dict(
Expand Down Expand Up @@ -599,7 +618,7 @@ def test_polyslab_axis_ops(axis):


@pytest.mark.skipif(not RUN_NUMERICAL, reason="Numerical gradient tests runs through web API.")
@pytest.mark.parametrize("structure_key, monitor_key", (("cylinder", "mode"),))
@pytest.mark.parametrize("structure_key, monitor_key", (("medium", "field_vol"),))
def test_autograd_numerical(structure_key, monitor_key):
"""Test an objective function through tidy3d autograd."""

Expand Down
4 changes: 4 additions & 0 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,12 +1093,16 @@ def shift_value(coords) -> float:
# Define source dataset
# Offset coordinates by source center since local coords are assumed in CustomCurrentSource

# import pdb; pdb.set_trace()

for freq0 in tuple(self.field_components.values())[0].coords["f"]:
src_field_components = {}
for name, field_component in self.field_components.items():
field_component = field_component.sel(f=freq0)
forward_amps = field_component.values
values = -1j * forward_amps
if "H" in name:
values *= -1
coords = dict(field_component.coords.copy())
for dim, key in enumerate("xyz"):
coords[key] = np.array(coords[key]) - source_geo.center[dim]
Expand Down
5 changes: 4 additions & 1 deletion tidy3d/components/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
class AdjointSourceInfo(Tidy3dBaseModel):
"""Stores information about the adjoint sources to pass to autograd pipeline."""

sources: tuple[SourceType, ...] = pd.Field(
sources: Tuple[annotate_type(SourceType), ...] = pd.Field(
...,
title="Adjoint Sources",
description="Set of processed sources to include in the adjoint simulation.",
Expand Down Expand Up @@ -1117,6 +1117,9 @@ def process_adjoint_sources(self, adj_srcs: list[SourceType]) -> AdjointSourceIn
adj_srcs, post_norm = self.process_adjoint_sources_broadband(adj_srcs)
return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=True)

import pdb

pdb.set_trace()
# if several spatial ports and several frequencies, try to fit
log.info("Adjoint source creation: trying multifrequency fit.")
adj_srcs, post_norm = self.process_adjoint_sources_fit(
Expand Down
40 changes: 39 additions & 1 deletion tidy3d/web/api/autograd/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,26 +731,64 @@ def setup_adj(

td.log.info("Running custom vjp (adjoint) pipeline.")

# import pdb; pdb.set_trace()

# immediately filter out any data_vjps with all 0's in the data
data_fields_vjp = {
data_fields_vjp_old = {
key: get_static(value) for key, value in data_fields_vjp.items() if not np.all(value == 0.0)
}
data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()}

# insert the raw VJP data into the .data of the original SimulationData
sim_data_vjp_old = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp_old)
sim_data_vjp = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp)

# make adjoint simulation from that SimulationData
data_vjp_paths_old = set(data_fields_vjp_old.keys())
data_vjp_paths = set(data_fields_vjp.keys())

num_monitors = len(sim_data_orig.simulation.monitors)
adjoint_monitors = sim_data_orig.simulation.with_adjoint_monitors(sim_fields_keys).monitors[
num_monitors:
]

sim_adj_old, adjoint_source_info_old = sim_data_vjp_old.make_adjoint_sim(
data_vjp_paths=data_vjp_paths_old, adjoint_monitors=adjoint_monitors
)

sim_adj, adjoint_source_info = sim_data_vjp.make_adjoint_sim(
data_vjp_paths=data_vjp_paths, adjoint_monitors=adjoint_monitors
)

# def get_current_amps(adjoint_source_info):
# vals = {}
# for key in ('Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz'):
# dataset = adjoint_source_info.sources[0].current_dataset
# if key in dataset.field_components:
# vals[key] = np.sum(abs(dataset.field_components[key]))
# return vals

# amps_old = get_current_amps(adjoint_source_info_old)
# amps_new = get_current_amps(adjoint_source_info)

# fld_mnt = td.FieldMonitor(
# size=(td.inf, td.inf, td.inf),
# freqs=adjoint_monitors[0].freqs,
# name='field',
# center=(0,0,0),
# )

# import tidy3d.web as web
# import matplotlib.pylab as plt

# sim_data_old = web.run(sim_adj_old.updated_copy(monitors=[fld_mnt]), task_name='old')
# sim_data_new = web.run(sim_adj.updated_copy(monitors=[fld_mnt]), task_name='old')

# sim_data_old.plot_field('field', 'E', 're')
# sim_data_new.plot_field('field', 'E', 're')

# import pdb; pdb.set_trace()

td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.")

return sim_adj, adjoint_source_info
Expand Down

0 comments on commit 505adb7

Please sign in to comment.