diff --git a/tests/test_plugins/smatrix/terminal_component_modeler_def.py b/tests/test_plugins/smatrix/terminal_component_modeler_def.py index a622faf38..5d281640d 100644 --- a/tests/test_plugins/smatrix/terminal_component_modeler_def.py +++ b/tests/test_plugins/smatrix/terminal_component_modeler_def.py @@ -295,7 +295,7 @@ def make_port(center, direction, type, name) -> Union[CoaxialLumpedPort, WavePor radius=mean_radius, num_points=41, normal_axis=2, - clockwise=False, + clockwise=direction != "+", ), ) return port diff --git a/tests/test_plugins/smatrix/test_terminal_component_modeler.py b/tests/test_plugins/smatrix/test_terminal_component_modeler.py index 010e47437..4836a2832 100644 --- a/tests/test_plugins/smatrix/test_terminal_component_modeler.py +++ b/tests/test_plugins/smatrix/test_terminal_component_modeler.py @@ -32,6 +32,12 @@ def run_component_modeler(monkeypatch, modeler: TerminalComponentModeler): "_compute_F", lambda matrix: 1.0 / (2.0 * np.sqrt(np.abs(matrix) + 1e-4)), ) + monkeypatch.setattr( + TerminalComponentModeler, + "_check_port_impedance_sign", + lambda self, Z_numpy: (), + ) + s_matrix = modeler._construct_smatrix() return s_matrix @@ -547,3 +553,12 @@ def test_port_source_snapped_to_PML(tmp_path): with pytest.raises(SetupError): modeler._shift_value_signed(port) + + +def test_wave_port_validate_current_integral(tmp_path): + """Checks that the current integral direction validator runs correctly.""" + modeler = make_coaxial_component_modeler( + path_dir=str(tmp_path), port_types=(WavePort, WavePort) + ) + with pytest.raises(pydantic.ValidationError): + _ = modeler.updated_copy(direction="-", path="ports/0/") diff --git a/tidy3d/plugins/microwave/custom_path_integrals.py b/tidy3d/plugins/microwave/custom_path_integrals.py index 251e97a57..5fcb98156 100644 --- a/tidy3d/plugins/microwave/custom_path_integrals.py +++ b/tidy3d/plugins/microwave/custom_path_integrals.py @@ -6,13 +6,14 @@ import numpy as np import pydantic.v1 as pd +import shapely import xarray as xr from ...components.base import cached_property from ...components.data.data_array import FreqDataArray, FreqModeDataArray, TimeDataArray from ...components.data.monitor_data import FieldData, FieldTimeData, ModeSolverData from ...components.geometry.base import Geometry -from ...components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate +from ...components.types import ArrayFloat2D, Ax, Axis, Bound, Coordinate, Direction from ...components.viz import add_ax_if_none from ...constants import MICROMETER, fp_eps from ...exceptions import DataError, SetupError @@ -396,3 +397,16 @@ def plot( arrowprops=ARROW_CURRENT, ) return ax + + @cached_property + def sign(self) -> Direction: + """Uses the ordering of the vertices to determine the direction of the current flow.""" + linestr = shapely.LineString(coordinates=self.vertices) + is_ccw = shapely.is_ccw(linestr) + # Invert statement when the vertices are given as (x, z) + if self.axis == 1: + is_ccw = not is_ccw + if is_ccw: + return "+" + else: + return "-" diff --git a/tidy3d/plugins/smatrix/component_modelers/terminal.py b/tidy3d/plugins/smatrix/component_modelers/terminal.py index 424b6a745..ec806f299 100644 --- a/tidy3d/plugins/smatrix/component_modelers/terminal.py +++ b/tidy3d/plugins/smatrix/component_modelers/terminal.py @@ -16,7 +16,7 @@ from ....components.types import Ax from ....components.viz import add_ax_if_none, equal_aspect from ....constants import C_0, OHM -from ....exceptions import ValidationError +from ....exceptions import Tidy3dError, ValidationError from ....web.api.container import BatchData from ..ports.base_lumped import AbstractLumpedPort from ..ports.base_terminal import AbstractTerminalPort, TerminalPortDataArray @@ -174,7 +174,7 @@ def sim_dict(self) -> Dict[str, Simulation]: @cached_property def _source_time(self): - """Helper to create a time domain pulse for the frequeny range of interest.""" + """Helper to create a time domain pulse for the frequency range of interest.""" freq0 = np.mean(self.freqs) fdiff = max(self.freqs) - min(self.freqs) fwidth = max(fdiff, freq0 * FWIDTH_FRAC) @@ -238,6 +238,15 @@ def port_VI(port_out: AbstractTerminalPort, sim_data: SimulationData): Z_numpy = port_impedances.transpose(*PortDataArray._dims).values.reshape( (len(self.freqs), len(port_names), 1) ) + + # Check to make sure sign is consistent for all impedance values + self._check_port_impedance_sign(Z_numpy) + + # Check for negative real part of port impedance and flip the V and Z signs accordingly + negative_real_Z = np.real(Z_numpy) < 0 + V_numpy = np.where(negative_real_Z, -V_numpy, V_numpy) + Z_numpy = np.where(negative_real_Z, -Z_numpy, Z_numpy) + F_numpy = TerminalComponentModeler._compute_F(Z_numpy) # Equation 4.67 - Pozar - Microwave Engineering 4ed @@ -385,3 +394,15 @@ def _set_port_data_array_attributes(data_array: PortDataArray) -> PortDataArray: """Helper to set additional metadata for ``PortDataArray``.""" data_array.name = "Z0" return data_array.assign_attrs(units=OHM, long_name="characteristic impedance") + + def _check_port_impedance_sign(self, Z_numpy: np.ndarray): + """Sanity check for consistent sign of real part of Z for each port across all frequencies.""" + for port_idx in range(Z_numpy.shape[1]): + port_Z = Z_numpy[:, port_idx, 0] + signs = np.sign(np.real(port_Z)) + if not np.all(signs == signs[0]): + raise Tidy3dError( + f"Inconsistent sign of real part of Z detected for port {port_idx}. " + "If you received this error, please create an issue in the Tidy3D " + "github repository." + ) diff --git a/tidy3d/plugins/smatrix/ports/wave.py b/tidy3d/plugins/smatrix/ports/wave.py index 28a5d759f..f9c5e9692 100644 --- a/tidy3d/plugins/smatrix/ports/wave.py +++ b/tidy3d/plugins/smatrix/ports/wave.py @@ -15,7 +15,11 @@ from ....components.source import GaussianPulse, ModeSource, ModeSpec from ....components.types import Bound, Direction, FreqArray from ....exceptions import ValidationError -from ...microwave import CurrentIntegralTypes, ImpedanceCalculator, VoltageIntegralTypes +from ...microwave import ( + CurrentIntegralTypes, + ImpedanceCalculator, + VoltageIntegralTypes, +) from ...mode import ModeSolver from .base_terminal import AbstractTerminalPort @@ -185,3 +189,19 @@ def _check_voltage_or_current(cls, val, values): "At least one of 'voltage_integral' or 'current_integral' must be provided." ) return val + + @pd.validator("current_integral", always=True) + def validate_current_integral_sign(cls, val, values): + """ + Validate that the sign of ``current_integral`` matches the port direction. + """ + if val is None: + return val + + direction = values.get("direction") + name = values.get("name") + if val.sign != direction: + raise ValidationError( + f"'current_integral' sign must match the '{name}' direction '{direction}'." + ) + return val