diff --git a/skeletor/skeletonize/__init__.py b/skeletor/skeletonize/__init__.py index 43fb9a4..9bfaec5 100644 --- a/skeletor/skeletonize/__init__.py +++ b/skeletor/skeletonize/__init__.py @@ -28,7 +28,8 @@ | function | speed | robust | radii [^2] | mesh map [^3] | description | | ------------------------------------------- | :---: | :----: | :--------: | :-----------: | ---------------------------------------------------| -| `skeletor.skeletonize.by_wavefront()` | +++ | ++ | yes | yes | works well for tubular meshes | +| `skeletor.skeletonize.by_wavefront()` | ++++ | ++ | yes | yes | works well for tubular meshes | +| `skeletor.skeletonize.by_wavefront_exact()` | +++ | ++ | yes | no | works well for tubular meshes | | `skeletor.skeletonize.by_vertex_clusters()` | ++ | + | no | yes | best with contracted meshes [^1] | | `skeletor.skeletonize.by_teasar()` | + | ++ | no | yes | works on mesh surface | | `skeletor.skeletonize.by_tangent_ball()` | ++ | 0 | yes | yes | works with mesh normals | @@ -47,9 +48,10 @@ from .edge_collapse import * from .vertex_cluster import * from .wave import * +from .wave2 import * from .teasar import * from .tangent_ball import * __docformat__ = "numpy" -__all__ = ['by_teasar', 'by_wavefront', 'by_vertex_clusters', +__all__ = ['by_teasar', 'by_wavefront', 'by_wavefront_exact', 'by_vertex_clusters', 'by_edge_collapse', 'by_tangent_ball'] diff --git a/skeletor/skeletonize/teasar.py b/skeletor/skeletonize/teasar.py index 4842c3c..0d4f647 100644 --- a/skeletor/skeletonize/teasar.py +++ b/skeletor/skeletonize/teasar.py @@ -98,7 +98,7 @@ def by_teasar(mesh, inv_dist, min_length=None, root=None, progress=True): with tqdm(desc='Invalidating', total=len(G.vs), disable=not progress, leave=False) as pbar: - for cc in sorted(G.clusters(), key=len, reverse=True): + for cc in sorted(G.connected_components(), key=len, reverse=True): # Make a subgraph for this connected component SG = G.subgraph(cc) cc = np.array(cc) diff --git a/skeletor/skeletonize/wave.py b/skeletor/skeletonize/wave.py index 7c1bbb7..5c4df9c 100644 --- a/skeletor/skeletonize/wave.py +++ b/skeletor/skeletonize/wave.py @@ -142,6 +142,7 @@ def by_wavefront(mesh, else: weights = np.linalg.norm(node_centers[el[:, 0]] - node_centers[el[:, 1]], axis=1) + tree = G.spanning_tree(weights=1 / weights) # Create a directed acyclic and hierarchical graph @@ -195,7 +196,7 @@ def _cast_waves(mesh, waves=1, origins=None, step_size=1, # Go over each connected component with tqdm(desc='Skeletonizing', total=len(G.vs), disable=not progress) as pbar: - for cc in G.clusters(): + for cc in G.connected_components(): # Make a subgraph for this connected component SG = G.subgraph(cc) cc = np.array(cc) @@ -241,7 +242,7 @@ def _cast_waves(mesh, waves=1, origins=None, step_size=1, this_dist = this_wave == i ix = np.where(this_dist)[0] SG2 = SG.subgraph(ix) - for cc2 in SG2.clusters(): + for cc2 in SG2.connected_components(): this_verts = cc[ix[cc2]] this_center = mesh.vertices[this_verts].mean(axis=0) this_radius = cdist(this_center.reshape(1, -1), mesh.vertices[this_verts]) diff --git a/skeletor/skeletonize/wave2.py b/skeletor/skeletonize/wave2.py new file mode 100644 index 0000000..eeb9e41 --- /dev/null +++ b/skeletor/skeletonize/wave2.py @@ -0,0 +1,290 @@ +# This script is part of skeletor (http://www.github.com/navis-org/skeletor). +# Copyright (C) 2018 Philipp Schlegel +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. + +import igraph as ig +import numpy as np +import pandas as pd + +from tqdm.auto import tqdm, trange + +from ..utilities import make_trimesh +from .base import Skeleton + + +def by_wavefront_exact(mesh, step_size, origins=None, radius_agg="mean", progress=True): + """Skeletonize a mesh using wave fronts. + + This is the _exact_ version of the `by_wavefront` function meaning + that each wave front moves exactly the given distance (see `step_size`) + along the mesh, instead of hoping from vertex to vertex. This is + computationally more expensive but also more accurate. + + Parameters + ---------- + mesh : mesh obj + The mesh to be skeletonize. Can an object that has + ``.vertices`` and ``.faces`` properties (e.g. a + trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a + dictionary ``{'vertices': vertices, 'faces': faces}``. + step_size : float | int (>0) + The distance each the wave fronts move along the mesh + in each step. + origins : int | list of ints, optional + Vertex ID(s) where the wave(s) are initialized. If there + is no origin for a given connected component, will fall + back to semi-random origin. + radius_agg : "mean" | "median" | "max" | "min" | "percentile75" | "percentile25" + Function used to aggregate radii over sample (i.e. the + vertices forming a ring that we collapse to its center). + progress : bool + If True, will show progress bar. + + Returns + ------- + skeletor.Skeleton + Holds results of the skeletonization and enables quick + visualization. + + """ + agg_map = { + "mean": np.mean, + "max": np.max, + "min": np.min, + "median": np.median, + "percentile75": lambda x: np.percentile(x, 75), + "percentile25": lambda x: np.percentile(x, 25), + } + assert radius_agg in agg_map, f'Unknown `radius_agg`: "{radius_agg}"' + rad_agg_func = agg_map[radius_agg] + + mesh = make_trimesh(mesh, validate=False) + + centers_final, radii_final, parents = _cast_waves( + mesh, + step_size=step_size, + origins=origins, + rad_agg_func=rad_agg_func, + progress=progress, + ) + + # Map radii for individual vertices to the collapsed nodes + # Using pandas is the fastest way here + swc = pd.DataFrame() + swc["node_id"] = np.arange(0, len(centers_final)) + swc["parent_id"] = parents + swc["x"] = centers_final[:, 0] + swc["y"] = centers_final[:, 1] + swc["z"] = centers_final[:, 2] + swc["radius"] = radii_final + + return Skeleton(swc=swc, mesh=mesh, method="wavefront_exact") + + +def _cast_waves(mesh, step_size, origins=None, rad_agg_func=np.mean, progress=True): + """Cast waves across mesh.""" + if not isinstance(origins, type(None)): + if isinstance(origins, int): + origins = [origins] + elif not isinstance(origins, (set, list)): + raise TypeError( + "`origins` must be vertex ID (int) or list " + f'thereof, got "{type(origins)}"' + ) + origins = np.asarray(origins).astype(int) + else: + origins = np.array([]) + + # Step size must be positive + if step_size < 0: + raise ValueError("`step_size` must be > 0") + + # Generate Graph (must be undirected) + G = ig.Graph(edges=mesh.edges_unique, directed=False) + G.es["weight"] = mesh.edges_unique_length + + # Prepare empty array to fill with centers + centers = [] + radii = [] + data = [] + + # Go over each connected component + with tqdm( + desc="Skeletonizing", total=len(G.vs), disable=not progress, leave=False + ) as pbar: + for k, cc in enumerate(G.connected_components()): + # Make a subgraph for this connected component + SG = G.subgraph(cc) + cc = np.array(cc) + + # Select seeds according to the number of waves + pot_seeds = np.arange(len(cc)) + np.random.seed(1985) # make seeds predictable + # See if we can use any origins + if len(origins): + # Get those origins in this cc + in_cc = np.isin(origins, cc) + if any(in_cc): + # Map origins into cc + cc_map = dict(zip(cc, np.arange(0, len(cc)))) + seed = np.array([cc_map[o] for o in origins[in_cc]])[0] + else: + seed = np.random.choice(pot_seeds) + else: + seed = np.random.choice(pot_seeds) + + # Get the distance between the seed and all other nodes + dist = np.array( + SG.distances(source=seed, target=None, mode="all", weights="weight") + )[0] + + # What's the max distance we can reach? + mx = dist[dist < float("inf")].max() + + # To keep track of which vertices we have already processed and can be + # safely ignored + is_inside_vert = np.zeros(len(dist)).astype(bool) + + # Collect groups + for i, d in enumerate(np.arange(0, mx, step_size)): + # Inner verts are all vertices that are within the current distance + # (minus those that we have already processed) + inside_verts = np.where((dist <= d) & ~is_inside_vert)[0] + + # To get the outer edge, we need to (1) find the neighbors of the inner edge + neighbors = [ + np.array(nn) for nn in SG.neighborhood(vertices=inside_verts) + ] + # (2) only keep those whose distance is above the current distance (i.e. those going outwards) + outer_verts = [nn[dist[nn] > d] for nn in neighbors] + + # Inside vertices that are not part of the inner ring(s) have no neighbors outside + is_inside_edge = np.array([len(ov) > 0 for ov in outer_verts]) + inner_verts = inside_verts[is_inside_edge] + outer_verts = [ + ov for ov, is_iv in zip(outer_verts, is_inside_edge) if is_iv + ] + + # To avoid doing the same work twice, we will mark inner vertices + is_inside_vert[inside_verts[~is_inside_edge]] = True + + # For each edge between inner-vertices and their outer neighbors calculate the + # exact position for the exact step_size distance + in_out_edges = np.array( + [ + [iv, ov] + for iv, ovs in zip(inner_verts, outer_verts) + for ov in ovs + ] + ) + in_pos = mesh.vertices[cc[in_out_edges[:, 0]]] + out_pos = mesh.vertices[cc[in_out_edges[:, 1]]] + + # The distance between the inner and outer vertices + in_out_dist = in_out_dist = np.sqrt( + ((in_pos - out_pos) ** 2).sum(axis=1) + ) + + # For each in_out_edge the remaining distance + d_left = d - dist[in_out_edges[:, 0]] + + # Move along the edges until we hit the target distance + new_verts = ( + in_pos + (out_pos - in_pos) * (d_left / in_out_dist)[:, None] + ) + + # To group vertices into rings we need to first determine which vertices + # are part of the same ring. We do this by checking which outer vertices + # are in connected components + outer_verts_unique = np.unique(np.concatenate(outer_verts)) + for cc2 in SG.subgraph(outer_verts_unique).connected_components(): + # Translate this connected components back to the vertex IDs in SG + outer_verts_cc2 = outer_verts_unique[cc2] + + # Get the new vertices that are part of this connected component + this_new_verts = new_verts[ + np.isin(in_out_edges[:, 1], outer_verts_cc2) + ] + + # Get the center of this connected component + this_center = this_new_verts.mean(axis=0) + + # Calculate the radius of this connected component + this_radius = rad_agg_func( + np.sqrt((this_new_verts - this_center) ** 2).sum(axis=1) + ) + + centers.append(this_center) # Store the center + radii.append(this_radius) # Store the radius + data.append( + (k, i, cc[outer_verts_cc2[0]]) + ) # Track the connected component, the step and a vertex from the outer ring for this center + + # pbar.update(len(cc)) + # Update progress bar based on the number of vertices we have invalidated + pbar.update((~is_inside_edge).sum()) + + centers = np.vstack(centers) + radii = np.array(radii) + data = np.array(data) + + # Next we have to reconstruct the parent-child relationships + # For each center, we know _a_ vertex that is part of the (outer) ring + # With that information we can ask, for each center, which center in the + # previous step is closest to it (as per distance along the mesh) + parents = np.full(len(centers), fill_value=-1, dtype=int) + + # This maps the step and the vertex ID to the index in the centers array + step_id_map = {(step, id): i for i, (step, id) in enumerate(data[:, 1:])} + + tree = G.spanning_tree() + + for i in trange( + 1, data[:, 1].max(), desc="Connecting", disable=not progress, leave=False + ): + is_this_step = data[:, 1] == i # All centers in this step + is_prev_step = data[:, 1] == i - 1 # All centers in the previous step + + # If there is only one center in the previous step, we can just go on and connect these + if is_prev_step.sum() == 1: + parents[is_this_step] = np.where(is_prev_step)[0][0] + continue + + this_step_track_verts = np.unique( + data[is_this_step, 2] + ) # All track-vertices in this step + prev_step_track_verts = np.unique( + data[is_prev_step, 2] + ) # All track-vertices in the previous step + + # Get distances between current and previous track vertices + # Note to self: this step takes up about 60% of the time at the moment + # There should be a way to speed this up. + d = np.array( + tree.distances( + source=this_step_track_verts, + target=prev_step_track_verts, + mode="all", + weights="weight", + ) + ) + + # Get the closest previous track-vertex for each current track-vertex + closest_prev_track_verts = prev_step_track_verts[d.argmin(axis=1)] + + for this, prev in zip(this_step_track_verts, closest_prev_track_verts): + parents[is_this_step & (data[:, 2] == this)] = step_id_map[(i - 1, prev)] + + return centers, radii, parents