Skip to content

Commit

Permalink
Parametrization.fit now returns a class object of type Parametrizatio…
Browse files Browse the repository at this point in the history
…n whose sample function can be invoked directly
  • Loading branch information
maurerv committed Oct 26, 2023
1 parent 833e0c5 commit 0b6fbd5
Showing 1 changed file with 64 additions and 20 deletions.
84 changes: 64 additions & 20 deletions colabseg/parametrization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,13 @@ class Sphere(Parametrization):
"""
Parametrize a point cloud as sphere.
"""
@staticmethod
def fit(positions: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:

def __init__(self, radius, center):
self.radius = radius
self.center = center

@classmethod
def fit(cls, positions: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Fit an sphere to a set of 3D points.
Expand Down Expand Up @@ -98,28 +103,33 @@ def fit(positions: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ (sol[2] * sol[2] / 4.0)
+ sol[3]
)
return cls(
radius = radius,
center = np.array([sol[0] / 2.0, sol[1] / 2.0, sol[2] / 2.0])
)

return radius, np.array([sol[0] / 2.0, sol[1] / 2.0, sol[2] / 2.0])

@staticmethod
def sample(n_samples: int, radius: np.ndarray, center: np.ndarray) -> np.ndarray:
def sample(self, n_samples: int, radius: np.ndarray = None,
center: np.ndarray = None) -> np.ndarray:
"""
Samples points from the surface of a sphere.
Parameters
----------
n_samples : int
Number of samples to draw
radius : np.ndarray
radius : np.ndarray, optional
Radius of the sphere
center : np.ndarray
center : np.ndarray, optional
Center of the sphere along each axis
Returns
-------
np.ndarray
Sampled points.
"""
center = self.center if center is None else center
radius = self.radius if radius is None else radius

sp = np.linspace(0, 2.0 * np.pi, num=n_samples)
x0, y0, z0 = center
nx = sp.shape[0]
Expand All @@ -137,8 +147,13 @@ class Ellipsoid(Parametrization):
Parametrize a point cloud as ellipsoid.
"""

@staticmethod
def fit(positions) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
def __init__(self, radii, center, orientations):
self.radius = radii
self.center = center
self.orientations = orientations

@classmethod
def fit(cls, positions) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Fit an ellipsoid to a set of 3D points.
Expand Down Expand Up @@ -207,11 +222,15 @@ def fit(positions) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

center = np.add(positions_mean, center)

return radii, center, evecs
return cls(
radii = radii,
center = center,
orientations = evecs
)

@staticmethod
def sample(
n_samples: int, radii: np.ndarray, center: np.ndarray, orientations: np.ndarray
self, n_samples: int, radii: np.ndarray = None,
center: np.ndarray= None, orientations: np.ndarray= None
) -> np.ndarray:
"""
Samples points from the surface of an ellisoid.
Expand All @@ -232,6 +251,9 @@ def sample(
np.ndarray
Sampled points.
"""
radii = self.radii if radii is None else radii
center = self.center if center is None else center
orientations = self.orientations if orientations is None else orientations

phi = np.random.uniform(0, 2 * np.pi, n_samples)
costheta = np.random.uniform(-1, 1, n_samples)
Expand All @@ -253,8 +275,15 @@ class Cylinder(Parametrization):
Parametrize a point cloud as cylinder.
"""

@staticmethod
def __init__(self, centers, angles, radius, height):
self.centers = centers
self.angles = angles
self.radius = radius
self.height = height

@classmethod
def fit(
cls,
positions: np.ndarray, initial_parameters: np.ndarray = None
) -> Tuple[np.ndarray, np.ndarray, float]:
"""
Expand Down Expand Up @@ -315,15 +344,20 @@ def error_function(p: np.ndarray, positions: np.ndarray) -> np.ndarray:
radius = parameters[4]
height = positions[:, 2].max() - positions[:, 2].min()

return centers, angles, radius, height
return cls(
centers = centers,
angles = angles,
radius = radius,
height = height
)

@staticmethod
def sample(
self,
n_samples: int,
centers: np.ndarray,
angles: np.ndarray,
radius: float,
height: float,
centers: np.ndarray = None,
angles: np.ndarray = None,
radius: float = None,
height: float = None,
) -> np.ndarray:
"""
Sample points from the surface of a cylinder.
Expand All @@ -346,6 +380,10 @@ def sample(
np.ndarray
Array of sampled points from the cylinder surface.
"""
centers = self.centers if centers is None else centers
angles = self.angles if angles is None else angles
radius = self.radius if radius is None else radius
height = self.height if height is None else height

Xc, Yc = centers

Expand All @@ -359,3 +397,9 @@ def sample(
z = h

return np.column_stack((x, y, z))

PARAMETRIZATION_TYPE = {
"sphere" : Sphere,
"ellipsoid" : Ellipsoid,
"cylinder" : Cylinder
}

0 comments on commit 0b6fbd5

Please sign in to comment.