diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ed17e02b2..40a12afb8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - `DataArray` interpolation failure due to incorrect ordering of coordinates when interpolating with autograd tracers. +- Error in `CustomSourceTime` when evaluating at a list of times entirely outside of the range of the envelope definition times. ## [2.7.2] - 2024-08-07 diff --git a/tests/test_components/test_source.py b/tests/test_components/test_source.py index 3fbf84ddfb..e4f01a0e81 100644 --- a/tests/test_components/test_source.py +++ b/tests/test_components/test_source.py @@ -307,8 +307,20 @@ def test_custom_source_time(log_capture): atol=ATOL, ) + # all times out of range + _ = cst.amp_time([-1]) + _ = cst.amp_time(-1) + assert np.allclose(cst.amp_time([2]), np.exp(-1j * 2 * np.pi * 2 * freq0), rtol=0, atol=ATOL) + assert_log_level(log_capture, None) + vals = td.components.data.data_array.TimeDataArray([1, 2], coords=dict(t=[-1, -0.5])) + dataset = td.components.data.dataset.TimeDataset(values=vals) + cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12) + source = td.PointDipole(center=(0, 0, 0), source_time=cst, polarization="Ex") + with AssertLogLevel(log_capture, "WARNING", contains_str="defined at times"): + sim = sim.updated_copy(sources=[source]) + # test normalization warning with AssertLogLevel(log_capture, "WARNING"): sim = sim.updated_copy(normalize_index=0) diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index e500001f6e..a05da33f5f 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -3161,6 +3161,31 @@ def _post_init_validators(self) -> None: self._validate_no_structures_pml() self._validate_tfsf_nonuniform_grid() self._validate_nonlinear_specs() + self._validate_custom_source_time() + + def _validate_custom_source_time(self): + """Warn if all simulation times are outside CustomSourceTime definition range.""" + # skip this validation if tmesh can't be computed, for example because of unloaded + # custom media + try: + _ = self.tmesh + except pydantic.ValidationError: + return + for idx, source in enumerate(self.sources): + if isinstance(source.source_time, CustomSourceTime): + if source.source_time._all_outside_range(tmesh=self.tmesh): + data_times = source.source_time.data_times + mint = np.min(data_times) + maxt = np.max(data_times) + mintmesh = np.min(self.tmesh) + maxtmesh = np.max(self.tmesh) + log.warning( + f"'CustomSourceTime' at 'sources[{idx}]' is defined at " + "times which do not include any of the 'Simulation.tmesh'. " + f"'CustomSourceTime' is defined in the time range " + f"'({mint}, {maxt})'; 'Simulation.tmesh' covers the range " + f"'({mintmesh}, {maxtmesh})'" + ) def _validate_no_structures_pml(self) -> None: """Ensure no structures terminate / have bounds inside of PML.""" diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index 5a7a245c3e..d3c222e542 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -352,6 +352,27 @@ def from_values( source_time_dataset=source_time_dataset, ) + @property + def data_times(self) -> ArrayFloat1D: + """Times of envelope definition.""" + if self.source_time_dataset is None: + return [] + data_times = self.source_time_dataset.values.coords["t"].values.squeeze() + return data_times + + def _all_outside_range(self, tmesh: ArrayFloat1D) -> bool: + """Whether all tmesh are outside range of definition.""" + # make time a numpy array for uniform handling + data_times = self.data_times + + # shift time + twidth = 1.0 / (2 * np.pi * self.fwidth) + time_shifted = tmesh - self.offset * twidth + + mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times)) + + return all(mask) + def amp_time(self, time: float) -> complex: """Complex-valued source amplitude as a function of time. @@ -370,8 +391,8 @@ def amp_time(self, time: float) -> complex: return None # make time a numpy array for uniform handling - times = np.array([time] if isinstance(time, float) else time) - data_times = self.source_time_dataset.values.coords["t"].values.squeeze() + times = np.array([time] if isinstance(time, (int, float)) else time) + data_times = self.data_times # shift time twidth = 1.0 / (2 * np.pi * self.fwidth) @@ -384,12 +405,13 @@ def amp_time(self, time: float) -> complex: envelope = np.zeros(len(time_shifted), dtype=complex) values = self.source_time_dataset.values envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy() - envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy() + if not all(mask): + envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy() # modulation, phase, amplitude omega0 = 2 * np.pi * self.freq0 offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * time) + oscillation = np.exp(-1j * omega0 * times) amp = self.amplitude return offset * oscillation * amp * envelope