Skip to content

Commit

Permalink
truncate adjoint filter if input shape small
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed Oct 31, 2023
1 parent c328507 commit 07a27f2
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- Fixed the duplication of log messages in Jupyter when `set_logging_file` is used.
- If input to circular filters in adjoint have size smaller than the diameter, instead of erroring, warn user and truncate the filter kernel accordingly.

## [2.5.0rc2] - 2023-10-30

Expand Down
14 changes: 14 additions & 0 deletions tests/test_plugins/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,20 @@ def test_adjoint_utils(strict_binarize):
_ = radius_penalty.evaluate(polyslab.vertices)


@pytest.mark.parametrize(
"input_size_y, log_level_expected", [(13, None), (12, "WARNING"), (11, "WARNING"), (14, None)]
)
def test_adjoint_filter_sizes(log_capture, input_size_y, log_level_expected):
"""Warn if filter size along a dim is smaller than radius."""

signal_in = np.ones((266, input_size_y))

_filter = ConicFilter(radius=0.08, design_region_dl=0.015)
_filter.evaluate(signal_in)

assert_log_level(log_capture, log_level_expected)


def test_sim_data_plot_field(use_emulated_run):
"""Test splitting of regular simulation data into user and server data."""

Expand Down
37 changes: 37 additions & 0 deletions tidy3d/plugins/adjoint/utils/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ....components.base import Tidy3dBaseModel
from ....constants import MICROMETER
from ....log import log


class Filter(Tidy3dBaseModel, ABC):
Expand Down Expand Up @@ -59,6 +60,39 @@ def _deprecate_feature_size(cls, values):
def make_kernel(self, coords_rad: jnp.array) -> jnp.array:
"""Function to make the kernel out of a coordinate grid of radius values."""

@staticmethod
def _check_kernel_size(kernel: jnp.array, signal_in: jnp.array) -> jnp.array:
"""Make sure kernel isn't larger than signal and warn and truncate if so."""

kernel_shape = kernel.shape
input_shape = signal_in.shape

if any((k_shape > in_shape for k_shape, in_shape in zip(kernel_shape, input_shape))):

# remove some pixels from the kernel to make things right
new_kernel = kernel.copy()
for axis, (len_kernel, len_input) in enumerate(zip(kernel_shape, input_shape)):
if len_kernel > len_input:
rm_pixels_total = len_kernel - len_input
rm_pixels_edge = int(np.ceil(rm_pixels_total / 2))
indices_truncated = np.arange(rm_pixels_edge, len_kernel - rm_pixels_edge)
new_kernel = new_kernel.take(indices=indices_truncated.astype(int), axis=axis)

log.warning(
f"The filter input has shape {input_shape} whereas the "
f"kernel has shape {kernel_shape}. "
"These shapes are incompatible as the input must "
"be larger than the kernel along all dimensions. "
"The kernel will automatically be "
f"resized to {new_kernel.shape} to be less than the input shape. "
"If this is unexpected, "
"either reduce the filter 'radius' or increase the input array's size."
)

return new_kernel

return kernel

def evaluate(self, spatial_data: jnp.array) -> jnp.array:
"""Process on supplied spatial data."""

Expand All @@ -74,6 +108,9 @@ def evaluate(self, spatial_data: jnp.array) -> jnp.array:
# construct the kernel
kernel = self.make_kernel(coords_rad)

# handle when kernel is too large compared to input
kernel = self._check_kernel_size(kernel=kernel, signal_in=rho)

# normalize by the kernel operating on a spatial_data of all ones
num = jsp.signal.convolve(rho, kernel, mode="same")
den = jsp.signal.convolve(jnp.ones_like(rho), kernel, mode="same")
Expand Down

0 comments on commit 07a27f2

Please sign in to comment.