diff --git a/tests/test_components/test_medium.py b/tests/test_components/test_medium.py index c04fdef4d8..d875d4c2e6 100644 --- a/tests/test_components/test_medium.py +++ b/tests/test_components/test_medium.py @@ -833,3 +833,38 @@ def test_coaxial_lumped_resistor(): normal_axis=1, name="R", ) + +def test_custom_medium(log_capture): + Nx, Ny, Nz, Nf = 4, 3, 1, 1 + X = np.linspace(-1, 1, Nx) + Y = np.linspace(-1, 1, Ny) + Z = [0] + freqs = [2e14] + n_data = np.ones((Nx, Ny, Nz, Nf)) + n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs)) + + def create_mediums(n_dataset): + ## Three equivalent ways of defining custom medium for the lens + + # define custom medium with n/k data + _ = td.CustomMedium.from_nk(n_dataset, interp_method="nearest") + + # define custom medium with permittivity data + eps_dataset = td.ScalarFieldDataArray(n_dataset**2, coords=dict(x=X, y=Y, z=Z, f=freqs)) + _ = td.CustomMedium.from_eps_raw(eps_dataset, interp_method="nearest") + + # define each component of permittivity via "PermittivityDataset" + eps_xyz_dataset = td.PermittivityDataset( + eps_xx=eps_dataset, eps_yy=eps_dataset, eps_zz=eps_dataset + ) + _ = td.CustomMedium(eps_dataset=eps_xyz_dataset, interp_method="nearest") + + create_mediums(n_dataset=n_dataset) + assert_log_level(log_capture, None) + + with pytest.raises(pydantic.ValidationError): + # repeat some entries so data cannot be interpolated + X2 = [X[0]] + list(X) + n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data)) + n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs)) + create_mediums(n_dataset=n_dataset2) diff --git a/tests/test_components/test_source.py b/tests/test_components/test_source.py index 819a26f160..c92f4bfb7c 100644 --- a/tests/test_components/test_source.py +++ b/tests/test_components/test_source.py @@ -323,3 +323,31 @@ def test_custom_source_time(log_capture): # noqa: F811 dataset = td.components.data.dataset.TimeDataset(values=vals) cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12) assert np.allclose(cst.amp_time([0]), [1], rtol=0, atol=ATOL) + + +def test_custom_field_source(log_capture): + Nx, Ny, Nz, Nf = 4, 3, 1, 1 + X = np.linspace(-1, 1, Nx) + Y = np.linspace(-1, 1, Ny) + Z = [0] + freqs = [2e14] + n_data = np.ones((Nx, Ny, Nz, Nf)) + n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs)) + + def make_custom_field_source(field_ds): + custom_source = td.CustomFieldSource( + center=(1, 1, 1), size=(2, 2, 0), source_time=ST, field_dataset=field_ds + ) + return custom_source + + field_dataset = td.FieldDataset(Ex=n_dataset, Hy=n_dataset) + make_custom_field_source(field_dataset) + assert_log_level(log_capture, None) + + with pytest.raises(pydantic.ValidationError): + # repeat some entries so data cannot be interpolated + X2 = [X[0]] + list(X) + n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data)) + n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs)) + field_dataset = td.FieldDataset(Ex=n_dataset, Hy=n_dataset2) + make_custom_field_source(field_dataset) diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index d12674aa27..1000abf866 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -5,6 +5,7 @@ import xarray as xr import numpy as np +import pandas import dask import h5py @@ -93,6 +94,48 @@ def assign_data_attrs(cls, val): val.attrs[attr_name] = attr return val + def _interp_validator(self, field_name: str = None) -> None: + """Make sure we can interp()/sel() the data.""" + # NOTE: this does not check every 'DataArray' by default. Instead, when required, this check can be + # called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'. + + if field_name is None: + field_name = "DataArray" + + dims = self.coords.dims + + for dim in dims: + # in case we encounter some /0 or /NaN we'll ignore the warnings here + with np.errstate(divide="ignore", invalid="ignore"): + # check that we can interpolate + try: + x0 = np.array(self.coords[dim][0]) + self.interp({dim: x0}, method="linear") + self.interp({dim: x0}, method="nearest") + # self.interp_like(self.isel({self.dim: 0})) + except pandas.errors.InvalidIndexError as e: + raise DataError( + f"'{field_name}.interp()' fails to interpolate along {dim} which is used by the solver. " + "This may be caused, for instance, by duplicated data " + f"in this dimension (you can verify this by running " + f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")' " + f"and interpolate with the new '{field_name}'). " + "Plase make sure data can be interpolated." + ) from e + # in case it can interpolate, try also to sel + try: + x0 = np.array(self.coords[dim][0]) + self.sel({dim: x0}, method="nearest") + except pandas.errors.InvalidIndexError as e: + raise DataError( + f"'{field_name}.sel()' fails to select along {dim} which is used by the solver. " + "This may be caused, for instance, by duplicated data " + f"in this dimension (you can verify this by running " + f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")' " + f"and run 'sel()' with the new '{field_name}'). " + "Plase make sure 'sel()' can be used on the 'DataArray'." + ) from e + @classmethod def assign_coord_attrs(cls, val): """Assign the correct coordinate attributes to the :class:`.DataArray`.""" diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 0326410b03..b46cb3babd 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -1896,6 +1896,15 @@ def _passivity_modulation_validation(cls, val, values): ) return val + @pd.validator("permittivity", "conductivity", always=True) + def _check_permittivity_conductivity_interpolate(cls, val, values, field): + """Check that the custom medium 'SpatialDataArrays' can be interpolated.""" + + if isinstance(val, SpatialDataArray): + val._interp_validator(field.name) + + return val + @cached_property def is_isotropic(self) -> bool: """Check if the medium is isotropic or anisotropic.""" diff --git a/tidy3d/components/source.py b/tidy3d/components/source.py index cd26f63882..5261860dd2 100644 --- a/tidy3d/components/source.py +++ b/tidy3d/components/source.py @@ -16,7 +16,7 @@ from .types import ArrayFloat1D, Axis, PlotVal, ArrayComplex1D, TYPE_TAG_STR from .validators import assert_plane, assert_volumetric from .validators import warn_if_dataset_none, assert_single_freq_in_range, _assert_min_freq -from .data.dataset import FieldDataset, TimeDataset +from .data.dataset import FieldDataset, TimeDataset, ScalarFieldDataArray from .data.validators import validate_no_nans from .data.data_array import TimeDataArray from .geometry.base import Box @@ -792,6 +792,15 @@ def _tangential_component_defined(cls, val: FieldDataset, values: dict) -> Field return val raise SetupError("No tangential field found in the suppled 'field_dataset'.") + @pydantic.validator("field_dataset", always=True) + def _check_fields_interpolate(cls, val: FieldDataset) -> FieldDataset: + """Checks whether the filds in 'field_dataset' can be interpolated.""" + if isinstance(val, FieldDataset): + for name, data in val.field_components.items(): + if isinstance(data, ScalarFieldDataArray): + data._interp_validator(name) + return val + """ Source current profiles defined by (1) angle or (2) desired mode. Sets theta and phi angles."""