Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate that custom datasets can interpolate #1684

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/test_components/test_medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,3 +833,39 @@ def test_coaxial_lumped_resistor():
normal_axis=1,
name="R",
)


def test_custom_medium(log_capture): # noqa: F811
Nx, Ny, Nz, Nf = 4, 3, 1, 1
X = np.linspace(-1, 1, Nx)
Y = np.linspace(-1, 1, Ny)
Z = [0]
freqs = [2e14]
n_data = np.ones((Nx, Ny, Nz, Nf))
n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs))

def create_mediums(n_dataset):
## Three equivalent ways of defining custom medium for the lens

# define custom medium with n/k data
_ = td.CustomMedium.from_nk(n_dataset, interp_method="nearest")

# define custom medium with permittivity data
eps_dataset = td.ScalarFieldDataArray(n_dataset**2, coords=dict(x=X, y=Y, z=Z, f=freqs))
_ = td.CustomMedium.from_eps_raw(eps_dataset, interp_method="nearest")

# define each component of permittivity via "PermittivityDataset"
eps_xyz_dataset = td.PermittivityDataset(
eps_xx=eps_dataset, eps_yy=eps_dataset, eps_zz=eps_dataset
)
_ = td.CustomMedium(eps_dataset=eps_xyz_dataset, interp_method="nearest")

create_mediums(n_dataset=n_dataset)
assert_log_level(log_capture, None)

with pytest.raises(pydantic.ValidationError):
# repeat some entries so data cannot be interpolated
X2 = [X[0]] + list(X)
n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data))
n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs))
create_mediums(n_dataset=n_dataset2)
28 changes: 28 additions & 0 deletions tests/test_components/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,31 @@ def test_custom_source_time(log_capture): # noqa: F811
dataset = td.components.data.dataset.TimeDataset(values=vals)
cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12)
assert np.allclose(cst.amp_time([0]), [1], rtol=0, atol=ATOL)


def test_custom_field_source(log_capture):
Nx, Ny, Nz, Nf = 4, 3, 1, 1
X = np.linspace(-1, 1, Nx)
Y = np.linspace(-1, 1, Ny)
Z = [0]
freqs = [2e14]
n_data = np.ones((Nx, Ny, Nz, Nf))
n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs))

def make_custom_field_source(field_ds):
custom_source = td.CustomFieldSource(
center=(1, 1, 1), size=(2, 2, 0), source_time=ST, field_dataset=field_ds
)
return custom_source

field_dataset = td.FieldDataset(Ex=n_dataset, Hy=n_dataset)
make_custom_field_source(field_dataset)
assert_log_level(log_capture, None)

with pytest.raises(pydantic.ValidationError):
# repeat some entries so data cannot be interpolated
X2 = [X[0]] + list(X)
n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data))
n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs))
field_dataset = td.FieldDataset(Ex=n_dataset, Hy=n_dataset2)
make_custom_field_source(field_dataset)
43 changes: 43 additions & 0 deletions tidy3d/components/data/data_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import xarray as xr
import numpy as np
import pandas
import dask
import h5py

Expand Down Expand Up @@ -93,6 +94,48 @@ def assign_data_attrs(cls, val):
val.attrs[attr_name] = attr
return val

def _interp_validator(self, field_name: str = None) -> None:
"""Make sure we can interp()/sel() the data."""
# NOTE: this does not check every 'DataArray' by default. Instead, when required, this check can be
# called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'.

if field_name is None:
field_name = "DataArray"

dims = self.coords.dims

for dim in dims:
# in case we encounter some /0 or /NaN we'll ignore the warnings here
with np.errstate(divide="ignore", invalid="ignore"):
# check that we can interpolate
try:
x0 = np.array(self.coords[dim][0])
self.interp({dim: x0}, method="linear")
self.interp({dim: x0}, method="nearest")
# self.interp_like(self.isel({self.dim: 0}))
except pandas.errors.InvalidIndexError as e:
raise DataError(
f"'{field_name}.interp()' fails to interpolate along {dim} which is used by the solver. "
"This may be caused, for instance, by duplicated data "
f"in this dimension (you can verify this by running "
f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")' "
f"and interpolate with the new '{field_name}'). "
"Plase make sure data can be interpolated."
) from e
# in case it can interpolate, try also to sel
try:
x0 = np.array(self.coords[dim][0])
self.sel({dim: x0}, method="nearest")
except pandas.errors.InvalidIndexError as e:
raise DataError(
marc-flex marked this conversation as resolved.
Show resolved Hide resolved
f"'{field_name}.sel()' fails to select along {dim} which is used by the solver. "
"This may be caused, for instance, by duplicated data "
f"in this dimension (you can verify this by running "
f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")' "
f"and run 'sel()' with the new '{field_name}'). "
"Plase make sure 'sel()' can be used on the 'DataArray'."
) from e

@classmethod
def assign_coord_attrs(cls, val):
"""Assign the correct coordinate attributes to the :class:`.DataArray`."""
Expand Down
9 changes: 9 additions & 0 deletions tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,15 @@ def _passivity_modulation_validation(cls, val, values):
)
return val

@pd.validator("permittivity", "conductivity", always=True)
def _check_permittivity_conductivity_interpolate(cls, val, values, field):
"""Check that the custom medium 'SpatialDataArrays' can be interpolated."""

if isinstance(val, SpatialDataArray):
val._interp_validator(field.name)

return val

@cached_property
def is_isotropic(self) -> bool:
"""Check if the medium is isotropic or anisotropic."""
Expand Down
11 changes: 10 additions & 1 deletion tidy3d/components/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .types import ArrayFloat1D, Axis, PlotVal, ArrayComplex1D, TYPE_TAG_STR
from .validators import assert_plane, assert_volumetric
from .validators import warn_if_dataset_none, assert_single_freq_in_range, _assert_min_freq
from .data.dataset import FieldDataset, TimeDataset
from .data.dataset import FieldDataset, TimeDataset, ScalarFieldDataArray
from .data.validators import validate_no_nans
from .data.data_array import TimeDataArray
from .geometry.base import Box
Expand Down Expand Up @@ -792,6 +792,15 @@ def _tangential_component_defined(cls, val: FieldDataset, values: dict) -> Field
return val
raise SetupError("No tangential field found in the suppled 'field_dataset'.")

@pydantic.validator("field_dataset", always=True)
def _check_fields_interpolate(cls, val: FieldDataset) -> FieldDataset:
"""Checks whether the filds in 'field_dataset' can be interpolated."""
if isinstance(val, FieldDataset):
for name, data in val.field_components.items():
if isinstance(data, ScalarFieldDataArray):
marc-flex marked this conversation as resolved.
Show resolved Hide resolved
data._interp_validator(name)
return val


""" Source current profiles defined by (1) angle or (2) desired mode. Sets theta and phi angles."""

Expand Down
Loading