From d42bb2420e03ef5a35ad741a5acd6849877ab774 Mon Sep 17 00:00:00 2001 From: Marc Bolinches Date: Wed, 8 May 2024 13:10:36 +0200 Subject: [PATCH] Adding test for custom medium --- tests/test_components/test_medium.py | 38 +++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/test_components/test_medium.py b/tests/test_components/test_medium.py index 4eba082b16..4c03d70203 100644 --- a/tests/test_components/test_medium.py +++ b/tests/test_components/test_medium.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import tidy3d as td from tidy3d.exceptions import ValidationError, SetupError -from ..utils import assert_log_level, log_capture, AssertLogLevel +from ..utils import assert_log_level, AssertLogLevel from typing import Dict MEDIUM = td.Medium() @@ -780,3 +780,39 @@ def test_lumped_resistor(): voltage_axis=2, name="R", ) + + +def test_custom_medium(log_capture): + Nx, Ny, Nz, Nf = 4, 3, 1, 1 + X = np.linspace(-1, 1, Nx) + Y = np.linspace(-1, 1, Ny) + Z = [0] + freqs = [2e14] + n_data = np.ones((Nx, Ny, Nz, Nf)) + n_dataset = td.ScalarFieldDataArray(n_data, coords=dict(x=X, y=Y, z=Z, f=freqs)) + + def create_mediums(n_dataset): + ## Three equivalent ways of defining custom medium for the lens + + # define custom medium with n/k data + _ = td.CustomMedium.from_nk(n_dataset, interp_method="nearest") + + # define custom medium with permittivity data + eps_dataset = td.ScalarFieldDataArray(n_dataset**2, coords=dict(x=X, y=Y, z=Z, f=freqs)) + _ = td.CustomMedium.from_eps_raw(eps_dataset, interp_method="nearest") + + # define each component of permittivity via "PermittivityDataset" + eps_xyz_dataset = td.PermittivityDataset( + eps_xx=eps_dataset, eps_yy=eps_dataset, eps_zz=eps_dataset + ) + _ = td.CustomMedium(eps_dataset=eps_xyz_dataset, interp_method="nearest") + + create_mediums(n_dataset=n_dataset) + assert_log_level(log_capture, None) + + with pytest.raises(pydantic.ValidationError): + # repeat some entries so data cannot be interpolated + X2 = [X[0]] + list(X) + n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data)) + n_dataset2 = td.ScalarFieldDataArray(n_data2, coords=dict(x=X2, y=Y, z=Z, f=freqs)) + create_mediums(n_dataset=n_dataset2)