Skip to content

Commit

Permalink
Geometry edges over field plots and fix mode solver plot return (#1628)
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Heitzmann Gabrielli <[email protected]>
  • Loading branch information
lucas-flexcompute authored and tylerflex committed May 20, 2024
1 parent 939214b commit fc47a01
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- `RectangularWaveguide.plot_field` optionally draws geometry edges over fields.

### Fixed
- `ModeSolver.plot_field` correctly returning the plot axes.

## [2.7.0rc2] - 2024-05-14

### Added
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/plugins/mode/mode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,7 @@ def plot_field(
"""

sim_data = self.sim_data
sim_data.plot_field(
return sim_data.plot_field(
field_monitor_name=MODE_MONITOR_NAME,
field_name=field_name,
val=val,
Expand Down
57 changes: 56 additions & 1 deletion tidy3d/plugins/waveguide/rectangular_dielectric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Any

import numpy
from matplotlib import pyplot
import pydantic.v1 as pydantic

from ...components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
Expand All @@ -14,6 +15,7 @@
from ...components.medium import Medium, MediumType
from ...components.mode import ModeSpec
from ...components.simulation import Simulation
from ...components.viz import add_ax_if_none

from ...components.source import ModeSource, GaussianPulse
from ...components.structure import Structure
Expand Down Expand Up @@ -915,6 +917,54 @@ def plot_grid(
**kwargs,
)

@add_ax_if_none
def plot_geometry_edges(
self,
color: str = "k",
ax: Ax = None,
) -> Ax:
"""Plot the waveguide cross-section geometry edges.
Parameters
----------
color : Color to use for the geometry edges.
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
Returns
-------
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""
kwargs = {"color": color, "linewidth": pyplot.rcParams["grid.linewidth"]}
if self.normal_axis < self.lateral_axis:
x0 = self.origin[self.lateral_axis]
ax.axvline(x0 + self.core_thickness + self.clad_thickness, **kwargs)
ax.axvline(x0, **kwargs)
ax.axvline(x0 - self.box_thickness, **kwargs)
else:
y0 = self.origin[self.normal_axis]
ax.axhline(y0 + self.core_thickness + self.clad_thickness, **kwargs)
ax.axhline(y0, **kwargs)
ax.axhline(y0 - self.box_thickness, **kwargs)

dx = self.core_thickness * numpy.tan(self.sidewall_angle)
u = min(self.lateral_axis, self.normal_axis)
v = max(self.lateral_axis, self.normal_axis)
for x0, w in zip(self._core_starts, self.core_width):
plot_x = []
plot_y = []
for x, y in zip(
[x0 - dx, x0, x0 + w, x0 + w + dx],
[0, self.core_thickness, self.core_thickness, 0],
):
p = self._translate(x, y, 0)
plot_x.append(p[u])
plot_y.append(p[v])
ax.plot(plot_x, plot_y, linestyle="-", **kwargs)
ax.set_aspect("equal")
return ax

def plot_field(
self,
field_name: str,
Expand All @@ -924,6 +974,7 @@ def plot_field(
vmin: float = None,
vmax: float = None,
ax: Ax = None,
geometry_edges: str = None,
**sel_kwargs,
) -> Ax:
"""Plot the field for a :class:`.ModeSolverData` with :class:`.Simulation` plot overlaid.
Expand Down Expand Up @@ -951,6 +1002,7 @@ def plot_field(
inferred from the data and other keyword arguments.
ax : matplotlib.axes._subplots.Axes = None
matplotlib axes to plot on, if not specified, one is created.
geometry_edges : Optional color to use for the geometry edges overlaid on the fields.
sel_kwargs : keyword arguments used to perform ``.sel()`` selection in the monitor data.
These kwargs can select over the spatial dimensions (``x``, ``y``, ``z``),
frequency or time dimensions (``f``, ``t``) or `mode_index`, if applicable.
Expand All @@ -963,7 +1015,7 @@ def plot_field(
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""
return self.mode_solver.plot_field(
ax = self.mode_solver.plot_field(
field_name=field_name,
val=val,
eps_alpha=eps_alpha,
Expand All @@ -973,3 +1025,6 @@ def plot_field(
ax=ax,
**sel_kwargs,
)
if geometry_edges is not None:
self.plot_geometry_edges(geometry_edges, ax=ax)
return ax

0 comments on commit fc47a01

Please sign in to comment.