Skip to content

Commit

Permalink
Add a ReverseCoords wrapper seed policy.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636591381
  • Loading branch information
Connectomics Team authored and copybara-github committed May 23, 2024
1 parent 207fc02 commit 371a597
Showing 1 changed file with 42 additions and 33 deletions.
75 changes: 42 additions & 33 deletions ffn/inference/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,8 @@ class PolicyPeaks2d(BaseSeedPolicy):
and peak finding to identify seed points.
"""

def __init__(self,
canvas,
min_distance=7,
threshold_abs=2.5,
sort_cmp='ascending',
**kwargs):
def __init__(self, canvas, min_distance=7, threshold_abs=2.5,
sort_cmp='ascending', **kwargs):
"""Initialize settings.
For compatibility with original version, min_distance=3, threshold_abs=0,
Expand Down Expand Up @@ -420,9 +416,9 @@ def init_coords(self):


class PolicyInvertOrigins(BaseSeedPolicy):
"""Reverse order of the seed locations used in a previous segmentation run."""

def __init__(self, canvas, corner=None, segmentation_dir=None,
**kwargs):
def __init__(self, canvas, corner=None, segmentation_dir=None, **kwargs):
super().__init__(canvas, **kwargs)
self.corner = corner
self.segmentation_dir = segmentation_dir
Expand All @@ -436,19 +432,55 @@ def init_coords(self):
in points])


class PolicyDenseSeeds(BaseSeedPolicy):
"""Dense seeds from thresholded image after optional erosion."""

def __init__(self, canvas: Any, threshold: float = 0.5, num_erosions: int = 0,
invert: bool = False, **kwargs):
super().__init__(canvas, **kwargs)
self._threshold = threshold
self._num_erosions = num_erosions
self._invert = invert

def init_coords(self):
img = self.canvas.image

x = np.array(img > self._threshold).astype(bool)
if self._invert:
x = ~x
for _ in range(self._num_erosions):
x = skimage.morphology.binary_erosion(x)
coords = np.where(x)

self.coords = np.array(coords).T


class ReverseCoords(BaseSeedPolicy):
"""Wraps another policy and just reverses the seed order."""

def __init__(self, canvas, policy_to_reverse: str, **policy_kwargs):
super().__init__(canvas)
policy_cls = globals()[policy_to_reverse]
self._policy = policy_cls(canvas, **policy_kwargs)

def init_coords(self):
self.coords = np.array(list(self._policy)[::-1])


class SequentialPolicies(BaseSeedPolicy):
"""Applies policies sequentially."""

def __init__(self, canvas, policies: Sequence[tuple[str, dict[str, Any]]],
**unused_kwargs):
**kwargs):
"""Initializes the policies.
Args:
canvas: inference Canvas object
policies: sequence of policies to chain together. Each entry is a tuple
of size two; the name of the policy, followed by a keyword dict.
**unused_kwargs: other keyword arguments.
**kwargs: other keyword arguments.
"""
del kwargs
super().__init__(canvas)
self._policies = []
for seed_policy, seed_policy_kwargs in policies:
Expand All @@ -473,26 +505,3 @@ def get_state(self, previous=False):
def set_state(self, state):
for s, policy in zip(state, self._policies):
policy.set_state(s)


class PolicyDenseSeeds(BaseSeedPolicy):
"""Dense seeds from thresholded image after optional erosion."""

def __init__(self, canvas: Any, threshold: float = 0.5, num_erosions: int = 0,
invert: bool = False, **kwargs):
super().__init__(canvas, **kwargs)
self._threshold = threshold
self._num_erosions = num_erosions
self._invert = invert

def init_coords(self):
img = self.canvas.image

x = np.array(img > self._threshold).astype(bool)
if self._invert:
x = ~x
for _ in range(self._num_erosions):
x = skimage.morphology.binary_erosion(x)
coords = np.where(x)

self.coords = np.array(coords).T

0 comments on commit 371a597

Please sign in to comment.