Adding DataArray validators for CustomFieldSource and CustomMedium, including tests.
marc-flex committed May 27, 2024
1 parent b54df57 commit 2c27053
Showing 5 changed files with 125 additions and 1 deletion.
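For orientation, here is a minimal sketch of the failure these validators now catch at construction time (not part of the commit; the API usage mirrors the tests below, and depending on the tidy3d version the import may need to be 'import pydantic.v1 as pydantic'):

import numpy as np
import pydantic
import tidy3d as td

# a duplicated 'x' coordinate makes the data non-interpolatable along 'x'
eps = td.SpatialDataArray(
    np.ones((3, 1, 1)), coords=dict(x=[0.0, 0.0, 1.0], y=[0.0], z=[0.0])
)

try:
    td.CustomMedium(permittivity=eps)
except pydantic.ValidationError as err:
    print(err)  # the new validator reports that 'interp()'/'sel()' fail along 'x'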
35 changes: 35 additions & 0 deletions tests/test_components/test_medium.py
@@ -833,3 +833,38 @@ def test_coaxial_lumped_resistor():
        normal_axis=1,
        name="R",
    )

def test_custom_medium(log_capture):
    Nx, Ny, Nz, Nf = 4, 3, 1, 1
    X = np.linspace(-1, 1, Nx)
    Y = np.linspace(-1, 1, Ny)
    Z = [0]
    freqs = [2e14]
    n_data = np.ones((Nx, Ny, Nz, Nf))
    n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs))

    def create_mediums(n_dataset):
        ## Three equivalent ways of defining custom medium for the lens

        # define custom medium with n/k data
        _ = td.CustomMedium.from_nk(n_dataset, interp_method="nearest")

        # define custom medium with permittivity data
        eps_dataset = td.ScalarFieldDataArray(n_dataset**2, coords=dict(x=X, y=Y, z=Z, f=freqs))
        _ = td.CustomMedium.from_eps_raw(eps_dataset, interp_method="nearest")

        # define each component of permittivity via "PermittivityDataset"
        eps_xyz_dataset = td.PermittivityDataset(
            eps_xx=eps_dataset, eps_yy=eps_dataset, eps_zz=eps_dataset
        )
        _ = td.CustomMedium(eps_dataset=eps_xyz_dataset, interp_method="nearest")

    create_mediums(n_dataset=n_dataset)
    assert_log_level(log_capture, None)

    with pytest.raises(pydantic.ValidationError):
        # repeat some entries so data cannot be interpolated
        X2 = [X[0]] + list(X)
        n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data))
        n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs))
        create_mediums(n_dataset=n_dataset2)
28 changes: 28 additions & 0 deletions tests/test_components/test_source.py
@@ -323,3 +323,31 @@ def test_custom_source_time(log_capture): # noqa: F811
    dataset = td.components.data.dataset.TimeDataset(values=vals)
    cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12)
    assert np.allclose(cst.amp_time([0]), [1], rtol=0, atol=ATOL)


def test_custom_field_source(log_capture):
    Nx, Ny, Nz, Nf = 4, 3, 1, 1
    X = np.linspace(-1, 1, Nx)
    Y = np.linspace(-1, 1, Ny)
    Z = [0]
    freqs = [2e14]
    n_data = np.ones((Nx, Ny, Nz, Nf))
    n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs))

    def make_custom_field_source(field_ds):
        custom_source = td.CustomFieldSource(
            center=(1, 1, 1), size=(2, 2, 0), source_time=ST, field_dataset=field_ds
        )
        return custom_source

    field_dataset = td.FieldDataset(Ex=n_dataset, Hy=n_dataset)
    make_custom_field_source(field_dataset)
    assert_log_level(log_capture, None)

    with pytest.raises(pydantic.ValidationError):
        # repeat some entries so data cannot be interpolated
        X2 = [X[0]] + list(X)
        n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data))
        n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs))
        field_dataset = td.FieldDataset(Ex=n_dataset, Hy=n_dataset2)
        make_custom_field_source(field_dataset)
43 changes: 43 additions & 0 deletions tidy3d/components/data/data_array.py
@@ -5,6 +5,7 @@

import xarray as xr
import numpy as np
import pandas
import dask
import h5py

@@ -93,6 +94,48 @@ def assign_data_attrs(cls, val):
            val.attrs[attr_name] = attr
        return val

    def _interp_validator(self, field_name: str = None) -> None:
        """Make sure we can interp()/sel() the data."""
        # NOTE: this does not check every 'DataArray' by default. Instead, when required, this
        # check can be called from a validator, as is the case with 'CustomMedium' and
        # 'CustomFieldSource'.

        if field_name is None:
            field_name = "DataArray"

        dims = self.coords.dims

        for dim in dims:
            # in case we encounter some /0 or /NaN we'll ignore the warnings here
            with np.errstate(divide="ignore", invalid="ignore"):
                # check that we can interpolate
                try:
                    x0 = np.array(self.coords[dim][0])
                    self.interp({dim: x0}, method="linear")
                    self.interp({dim: x0}, method="nearest")
                except pandas.errors.InvalidIndexError as e:
                    raise DataError(
                        f"'{field_name}.interp()' fails to interpolate along '{dim}', which is used by the solver. "
                        "This may be caused, for instance, by duplicated data "
                        f"in this dimension (you can verify this by running "
                        f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")' "
                        f"and interpolating with the new '{field_name}'). "
                        "Please make sure the data can be interpolated."
                    ) from e
                # even if interpolation works, check that we can also select
                try:
                    x0 = np.array(self.coords[dim][0])
                    self.sel({dim: x0}, method="nearest")
                except pandas.errors.InvalidIndexError as e:
                    raise DataError(
                        f"'{field_name}.sel()' fails to select along '{dim}', which is used by the solver. "
                        "This may be caused, for instance, by duplicated data "
                        f"in this dimension (you can verify this by running "
                        f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")' "
                        f"and running 'sel()' with the new '{field_name}'). "
                        "Please make sure 'sel()' can be used on the 'DataArray'."
                    ) from e

    @classmethod
    def assign_coord_attrs(cls, val):
        """Assign the correct coordinate attributes to the :class:`.DataArray`."""
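For context, the failure mode guarded against here can be reproduced with plain xarray: in the pandas/xarray versions this targets, duplicated coordinates surface as 'pandas.errors.InvalidIndexError', the exception '_interp_validator' catches. A minimal sketch:

import numpy as np
import pandas
import xarray as xr

# a DataArray whose 'x' coordinate contains a duplicated entry
da = xr.DataArray(np.ones(4), coords=dict(x=[0.0, 0.0, 1.0, 2.0]), dims="x")

try:
    da.interp(x=0.5, method="linear")
except pandas.errors.InvalidIndexError as err:
    print("interp() failed:", err)

# the remedy suggested in the error message: drop the duplicated coordinates
print(da.drop_duplicates(dim="x").interp(x=0.5, method="linear").values)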
9 changes: 9 additions & 0 deletions tidy3d/components/medium.py
@@ -1896,6 +1896,15 @@ def _passivity_modulation_validation(cls, val, values):
                )
        return val

    @pd.validator("permittivity", "conductivity", always=True)
    def _check_permittivity_conductivity_interpolate(cls, val, values, field):
        """Check that the custom medium 'SpatialDataArrays' can be interpolated."""

        if isinstance(val, SpatialDataArray):
            val._interp_validator(field.name)

        return val

    @cached_property
    def is_isotropic(self) -> bool:
        """Check if the medium is isotropic or anisotropic."""
11 changes: 10 additions & 1 deletion tidy3d/components/source.py
@@ -16,7 +16,7 @@
from .types import ArrayFloat1D, Axis, PlotVal, ArrayComplex1D, TYPE_TAG_STR
from .validators import assert_plane, assert_volumetric
from .validators import warn_if_dataset_none, assert_single_freq_in_range, _assert_min_freq
from .data.dataset import FieldDataset, TimeDataset
from .data.dataset import FieldDataset, TimeDataset, ScalarFieldDataArray
from .data.validators import validate_no_nans
from .data.data_array import TimeDataArray
from .geometry.base import Box
@@ -792,6 +792,15 @@ def _tangential_component_defined(cls, val: FieldDataset, values: dict) -> FieldDataset:
                return val
        raise SetupError("No tangential field found in the supplied 'field_dataset'.")

    @pydantic.validator("field_dataset", always=True)
    def _check_fields_interpolate(cls, val: FieldDataset) -> FieldDataset:
        """Checks whether the fields in 'field_dataset' can be interpolated."""
        if isinstance(val, FieldDataset):
            for name, data in val.field_components.items():
                if isinstance(data, ScalarFieldDataArray):
                    data._interp_validator(name)
        return val


""" Source current profiles defined by (1) angle or (2) desired mode. Sets theta and phi angles."""

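As a usage note, a hedged sketch of hitting and then fixing this validator when building a 'CustomFieldSource' (values are illustrative; 'td.GaussianPulse' stands in for any 'SourceTime'):

import numpy as np
import tidy3d as td

freq0 = 2e14
st = td.GaussianPulse(freq0=freq0, fwidth=1e13)

# 'x' contains a duplicated entry, so the field data cannot be interpolated
ex = td.ScalarFieldDataArray(
    np.ones((4, 2, 1, 1)),
    coords=dict(x=[0.0, 0.0, 0.5, 1.0], y=[0.0, 1.0], z=[0.0], f=[freq0]),
)

# this would raise pydantic.ValidationError via '_check_fields_interpolate':
# td.CustomFieldSource(center=(0.5, 0.5, 0), size=(1, 1, 0), source_time=st,
#                      field_dataset=td.FieldDataset(Ex=ex))

# dropping the duplicates first makes the dataset interpolatable again
source = td.CustomFieldSource(
    center=(0.5, 0.5, 0),
    size=(1, 1, 0),
    source_time=st,
    field_dataset=td.FieldDataset(Ex=ex.drop_duplicates(dim="x")),
)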
