From 505adb7cabdc3270a986daacd9dd4aabe25e9509 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Tue, 27 Aug 2024 12:56:30 -0400 Subject: [PATCH] fix FieldData gradients --- tests/test_components/test_autograd.py | 37 ++++++++++++++++++------ tidy3d/components/data/monitor_data.py | 4 +++ tidy3d/components/data/sim_data.py | 5 +++- tidy3d/web/api/autograd/autograd.py | 40 +++++++++++++++++++++++++- 4 files changed, 75 insertions(+), 11 deletions(-) diff --git a/tests/test_components/test_autograd.py b/tests/test_components/test_autograd.py index 58ccb4c1f..6b7cf4263 100644 --- a/tests/test_components/test_autograd.py +++ b/tests/test_components/test_autograd.py @@ -36,7 +36,7 @@ TEST_POLYSLAB_SPEED = False # whether to run numerical gradient tests, off by default because it runs real simulations -RUN_NUMERICAL = False +RUN_NUMERICAL = True TEST_MODES = ("pipeline", "adjoint", "speed") TEST_MODE = "speed" if TEST_POLYSLAB_SPEED else "pipeline" @@ -449,13 +449,26 @@ def diff_postprocess_fn(sim_data, mnt_data): name="field_vol", ) + # def field_vol_postprocess_fn(sim_data, mnt_data): + # value = 0.0 + # for _, val in mnt_data.field_components.items(): + # for key in ('Hx', 'Hy', 'Hz'): + # val = mnt_data.field_components[key] + # # # if key == 'Ex': + # # value += abs(anp.sum(abs(val.values))) + # # else: + # # value += 0 * abs(anp.sum(abs(val.values))) + + # value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values) + # return value + def field_vol_postprocess_fn(sim_data, mnt_data): value = 0.0 - for _, val in mnt_data.field_components.items(): - value = value + abs(anp.sum(val.values)) + # for _, val in mnt_data.field_components.items(): + # value = value + abs(anp.sum(val.values)) # field components numerical is 3x higher - intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)) - value += intensity + # intensity = anp.nan_to_num(anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values)) + # value += intensity # intensity numerical is 4.79x higher value += anp.sum(mnt_data.flux.values) # flux is 18.4x lower @@ -470,9 +483,15 @@ def field_vol_postprocess_fn(sim_data, mnt_data): def field_point_postprocess_fn(sim_data, mnt_data): value = 0.0 - for _, val in mnt_data.field_components.items(): - value += abs(anp.sum(val.values)) - value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values) + # for _, val in mnt_data.field_components.items(): + for key in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"): + val = mnt_data.field_components[key] + if key == "Ey": + value += abs(anp.sum(abs(val.values))) + else: + value += 0 * abs(anp.sum(abs(val.values))) + + # value += anp.sum(sim_data.get_intensity(mnt_data.monitor.name).values) return value return dict( @@ -599,7 +618,7 @@ def test_polyslab_axis_ops(axis): @pytest.mark.skipif(not RUN_NUMERICAL, reason="Numerical gradient tests runs through web API.") -@pytest.mark.parametrize("structure_key, monitor_key", (("cylinder", "mode"),)) +@pytest.mark.parametrize("structure_key, monitor_key", (("medium", "field_vol"),)) def test_autograd_numerical(structure_key, monitor_key): """Test an objective function through tidy3d autograd.""" diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index 8c3fc1299..5dba1c19c 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -1093,12 +1093,16 @@ def shift_value(coords) -> float: # Define source dataset # Offset coordinates by source center since local coords are assumed in CustomCurrentSource + # import pdb; pdb.set_trace() + for freq0 in tuple(self.field_components.values())[0].coords["f"]: src_field_components = {} for name, field_component in self.field_components.items(): field_component = field_component.sel(f=freq0) forward_amps = field_component.values values = -1j * forward_amps + if "H" in name: + values *= -1 coords = dict(field_component.coords.copy()) for dim, key in enumerate("xyz"): coords[key] = np.array(coords[key]) - source_geo.center[dim] diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index 924b8208b..19920f1aa 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -46,7 +46,7 @@ class AdjointSourceInfo(Tidy3dBaseModel): """Stores information about the adjoint sources to pass to autograd pipeline.""" - sources: tuple[SourceType, ...] = pd.Field( + sources: Tuple[annotate_type(SourceType), ...] = pd.Field( ..., title="Adjoint Sources", description="Set of processed sources to include in the adjoint simulation.", @@ -1117,6 +1117,9 @@ def process_adjoint_sources(self, adj_srcs: list[SourceType]) -> AdjointSourceIn adj_srcs, post_norm = self.process_adjoint_sources_broadband(adj_srcs) return AdjointSourceInfo(sources=adj_srcs, post_norm=post_norm, normalize_sim=True) + import pdb + + pdb.set_trace() # if several spatial ports and several frequencies, try to fit log.info("Adjoint source creation: trying multifrequency fit.") adj_srcs, post_norm = self.process_adjoint_sources_fit( diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index 2c5093ce0..b5a5eb3d8 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -731,15 +731,20 @@ def setup_adj( td.log.info("Running custom vjp (adjoint) pipeline.") + # import pdb; pdb.set_trace() + # immediately filter out any data_vjps with all 0's in the data - data_fields_vjp = { + data_fields_vjp_old = { key: get_static(value) for key, value in data_fields_vjp.items() if not np.all(value == 0.0) } + data_fields_vjp = {key: get_static(value) for key, value in data_fields_vjp.items()} # insert the raw VJP data into the .data of the original SimulationData + sim_data_vjp_old = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp_old) sim_data_vjp = sim_data_orig.insert_traced_fields(field_mapping=data_fields_vjp) # make adjoint simulation from that SimulationData + data_vjp_paths_old = set(data_fields_vjp_old.keys()) data_vjp_paths = set(data_fields_vjp.keys()) num_monitors = len(sim_data_orig.simulation.monitors) @@ -747,10 +752,43 @@ def setup_adj( num_monitors: ] + sim_adj_old, adjoint_source_info_old = sim_data_vjp_old.make_adjoint_sim( + data_vjp_paths=data_vjp_paths_old, adjoint_monitors=adjoint_monitors + ) + sim_adj, adjoint_source_info = sim_data_vjp.make_adjoint_sim( data_vjp_paths=data_vjp_paths, adjoint_monitors=adjoint_monitors ) + # def get_current_amps(adjoint_source_info): + # vals = {} + # for key in ('Ex', 'Ey', 'Ez', 'Hx', 'Hy', 'Hz'): + # dataset = adjoint_source_info.sources[0].current_dataset + # if key in dataset.field_components: + # vals[key] = np.sum(abs(dataset.field_components[key])) + # return vals + + # amps_old = get_current_amps(adjoint_source_info_old) + # amps_new = get_current_amps(adjoint_source_info) + + # fld_mnt = td.FieldMonitor( + # size=(td.inf, td.inf, td.inf), + # freqs=adjoint_monitors[0].freqs, + # name='field', + # center=(0,0,0), + # ) + + # import tidy3d.web as web + # import matplotlib.pylab as plt + + # sim_data_old = web.run(sim_adj_old.updated_copy(monitors=[fld_mnt]), task_name='old') + # sim_data_new = web.run(sim_adj.updated_copy(monitors=[fld_mnt]), task_name='old') + + # sim_data_old.plot_field('field', 'E', 're') + # sim_data_new.plot_field('field', 'E', 're') + + # import pdb; pdb.set_trace() + td.log.info(f"Adjoint simulation created with {len(sim_adj.sources)} sources.") return sim_adj, adjoint_source_info