Skip to content

Commit

Permalink
Fix CustomSourceTime with times completely outside envelope definitio…
Browse files Browse the repository at this point in the history
…n time range
  • Loading branch information
caseyflex committed Aug 13, 2024
1 parent eaf9bc6 commit 90c5441
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- `DataArray` interpolation failure due to incorrect ordering of coordinates when interpolating with autograd tracers.
- Error in `CustomSourceTime` when evaluating at a list of times entirely outside of the range of the envelope definition times.

## [2.7.2] - 2024-08-07

Expand Down
12 changes: 12 additions & 0 deletions tests/test_components/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,20 @@ def test_custom_source_time(log_capture):
atol=ATOL,
)

# all times out of range
_ = cst.amp_time([-1])
_ = cst.amp_time(-1)
assert np.allclose(cst.amp_time([2]), np.exp(-1j * 2 * np.pi * 2 * freq0), rtol=0, atol=ATOL)

assert_log_level(log_capture, None)

vals = td.components.data.data_array.TimeDataArray([1, 2], coords=dict(t=[-1, -0.5]))
dataset = td.components.data.dataset.TimeDataset(values=vals)
cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12)
source = td.PointDipole(center=(0, 0, 0), source_time=cst, polarization="Ex")
with AssertLogLevel(log_capture, "WARNING", contains_str="defined at times"):
sim = sim.updated_copy(sources=[source])

# test normalization warning
with AssertLogLevel(log_capture, "WARNING"):
sim = sim.updated_copy(normalize_index=0)
Expand Down
25 changes: 25 additions & 0 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3161,6 +3161,31 @@ def _post_init_validators(self) -> None:
self._validate_no_structures_pml()
self._validate_tfsf_nonuniform_grid()
self._validate_nonlinear_specs()
self._validate_custom_source_time()

def _validate_custom_source_time(self):
"""Warn if all simulation times are outside CustomSourceTime definition range."""
# skip this validation if tmesh can't be computed, for example because of unloaded
# custom media
try:
_ = self.tmesh
except pydantic.ValidationError:
return
for idx, source in enumerate(self.sources):
if isinstance(source.source_time, CustomSourceTime):
if source.source_time._all_outside_range(tmesh=self.tmesh):
data_times = source.source_time.data_times
mint = np.min(data_times)
maxt = np.max(data_times)
mintmesh = np.min(self.tmesh)
maxtmesh = np.max(self.tmesh)
log.warning(
f"'CustomSourceTime' at 'sources[{idx}]' is defined at "
"times which do not include any of the 'Simulation.tmesh'. "
f"'CustomSourceTime' is defined in the time range "
f"'({mint}, {maxt})'; 'Simulation.tmesh' covers the range "
f"'({mintmesh}, {maxtmesh})'"
)

def _validate_no_structures_pml(self) -> None:
"""Ensure no structures terminate / have bounds inside of PML."""
Expand Down
30 changes: 26 additions & 4 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,27 @@ def from_values(
source_time_dataset=source_time_dataset,
)

@property
def data_times(self) -> ArrayFloat1D:
"""Times of envelope definition."""
if self.source_time_dataset is None:
return []
data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
return data_times

def _all_outside_range(self, tmesh: ArrayFloat1D) -> bool:
"""Whether all tmesh are outside range of definition."""
# make time a numpy array for uniform handling
data_times = self.data_times

# shift time
twidth = 1.0 / (2 * np.pi * self.fwidth)
time_shifted = tmesh - self.offset * twidth

mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times))

return all(mask)

def amp_time(self, time: float) -> complex:
"""Complex-valued source amplitude as a function of time.
Expand All @@ -370,8 +391,8 @@ def amp_time(self, time: float) -> complex:
return None

# make time a numpy array for uniform handling
times = np.array([time] if isinstance(time, float) else time)
data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
times = np.array([time] if isinstance(time, (int, float)) else time)
data_times = self.data_times

# shift time
twidth = 1.0 / (2 * np.pi * self.fwidth)
Expand All @@ -384,12 +405,13 @@ def amp_time(self, time: float) -> complex:
envelope = np.zeros(len(time_shifted), dtype=complex)
values = self.source_time_dataset.values
envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy()
envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()
if not all(mask):
envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()

# modulation, phase, amplitude
omega0 = 2 * np.pi * self.freq0
offset = np.exp(1j * self.phase)
oscillation = np.exp(-1j * omega0 * time)
oscillation = np.exp(-1j * omega0 * times)
amp = self.amplitude

return offset * oscillation * amp * envelope
Expand Down

0 comments on commit 90c5441

Please sign in to comment.