Skip to content

Commit

Permalink
Add class-based filters
Browse files Browse the repository at this point in the history
  • Loading branch information
yaugenst-flex committed Aug 29, 2024
1 parent da70f10 commit a48c05a
Showing 1 changed file with 135 additions and 24 deletions.
159 changes: 135 additions & 24 deletions tidy3d/plugins/autograd/invdes/filters.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,133 @@
from functools import partial
from typing import Callable, Tuple, Union
import abc
from functools import lru_cache, partial
from typing import Callable, Iterable, Tuple, Union

import numpy as np
from numpy.typing import NDArray
import pydantic as pd

from tidy3d.components.base import Tidy3dBaseModel

from ..functions import convolve
from ..types import KernelType, PaddingType
from ..utilities import get_kernel_size_px, make_kernel


class AbstractFilter(Tidy3dBaseModel, abc.ABC):
"""A filter class for creating and applying convolution filters.
Parameters
----------
kernel_size : Tuple[pd.PositiveInt, ...]
Size of the kernel in pixels for each dimension.
normalize : bool = True
Whether to normalize the kernel so that it sums to 1.
padding : PaddingType = "reflect"
The padding mode to use.
"""

kernel_size: Tuple[pd.PositiveInt, ...] = pd.Field(
..., description="Size of the kernel in pixels for each dimension."
)
normalize: bool = pd.Field(
True, description="Whether to normalize the kernel so that it sums to 1."
)
padding: PaddingType = pd.Field("reflect", description="The padding mode to use.")

@classmethod
def from_radius_dl(
cls, radius: Union[float, Tuple[float, ...]], dl: Union[float, Tuple[float, ...]], **kwargs
) -> "AbstractFilter":
"""Create a filter from radius and grid spacing.
Parameters
----------
radius : Union[float, Tuple[float, ...]]
The radius of the kernel. Can be a scalar or a tuple.
dl : Union[float, Tuple[float, ...]]
The grid spacing. Can be a scalar or a tuple.
**kwargs
Additional keyword arguments to pass to the filter constructor.
Returns
-------
AbstractFilter
An instance of the filter.
"""
kernel_size = get_kernel_size_px(radius=radius, dl=dl)
return cls(kernel_size, **kwargs)

@staticmethod
@abc.abstractmethod
def get_kernel(size_px: Iterable[int], normalize: bool) -> np.ndarray:
"""Get the kernel for the filter.
Parameters
----------
size_px : Iterable[int]
Size of the kernel in pixels for each dimension.
normalize : bool
Whether to normalize the kernel so that it sums to 1.
Returns
-------
np.ndarray
The kernel.
"""
...

def __call__(self, array: np.ndarray) -> np.ndarray:
"""Apply the filter to an input array.
Parameters
----------
array : np.ndarray
The input array to filter.
Returns
-------
np.ndarray
The filtered array.
"""
original_shape = array.shape
squeezed_array = np.squeeze(array)
size_px = self.kernel_size
if len(size_px) != squeezed_array.ndim:
size_px *= squeezed_array.ndim
kernel = self.get_kernel(size_px, self.normalize)
convolved_array = convolve(squeezed_array, kernel, padding=self.padding)
return np.reshape(convolved_array, original_shape)


class ConicFilter(AbstractFilter):
"""A conic filter for creating and applying convolution filters."""

@staticmethod
@lru_cache(maxsize=1)
def get_kernel(size_px: Iterable[int], normalize: bool) -> np.ndarray:
"""Get the conic kernel.
See Also
--------
:func:`~filters.AbstractFilter.get_kernel` For full method documentation.
"""
return make_kernel(kernel_type="conic", size=size_px, normalize=normalize)


class CircularFilter(AbstractFilter):
"""A circular filter for creating and applying convolution filters."""

@staticmethod
@lru_cache(maxsize=1)
def get_kernel(size_px: Iterable[int], normalize: bool) -> np.ndarray:
"""Get the circular kernel.
See Also
--------
:func:`~filters.AbstractFilter.get_kernel` For full method documentation.
"""
return make_kernel(kernel_type="circular", size=size_px, normalize=normalize)


def _get_kernel_size(
radius: Union[float, Tuple[float, ...]],
dl: Union[float, Tuple[float, ...]],
Expand Down Expand Up @@ -52,7 +171,7 @@ def make_filter(
normalize: bool = True,
padding: PaddingType = "reflect",
filter_type: KernelType,
) -> Callable:
) -> Callable[[np.ndarray], np.ndarray]:
"""Create a filter function based on the specified kernel type and size.
Parameters
Expand All @@ -72,42 +191,34 @@ def make_filter(
Returns
-------
function
Callable[[np.ndarray], np.ndarray]
A function that applies the created filter to an input array.
"""
_kernel = {}

def _filter(array: NDArray) -> NDArray:
kernel_size = _get_kernel_size(radius, dl, size_px)
original_shape = array.shape
squeezed_array = np.squeeze(array)
kernel_size = _get_kernel_size(radius, dl, size_px)

if squeezed_array.ndim not in _kernel:
ks = kernel_size
if len(ks) != squeezed_array.ndim:
ks *= squeezed_array.ndim
_kernel[squeezed_array.ndim] = make_kernel(
kernel_type=filter_type, size=ks, normalize=normalize
)

convolved_array = convolve(squeezed_array, _kernel[squeezed_array.ndim], padding=padding)
return np.reshape(convolved_array, original_shape)
if filter_type == "conic":
filter_class = ConicFilter
elif filter_type == "circular":
filter_class = CircularFilter
else:
raise ValueError(f"Unsupported filter_type: {filter_type}")

return _filter
filter_instance = filter_class(kernel_size=kernel_size, normalize=normalize, padding=padding)
return filter_instance


make_conic_filter = partial(make_filter, filter_type="conic")
make_conic_filter.__doc__ = """make_filter() with a default filter_type value of `conic`.
See Also
--------
make_filter : Function to create a filter based on the specified kernel type and size.
:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size.
"""

make_circular_filter = partial(make_filter, filter_type="circular")
make_circular_filter.__doc__ = """make_filter() with a default filter_type value of `circular`.
See Also
--------
make_filter : Function to create a filter based on the specified kernel type and size.
:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size.
"""

0 comments on commit a48c05a

Please sign in to comment.