diff --git a/meshparty/skeleton.py b/meshparty/skeleton.py index dcb4e49..4594661 100644 --- a/meshparty/skeleton.py +++ b/meshparty/skeleton.py @@ -1,12 +1,15 @@ import numpy as np from meshparty import utils -from scipy import spatial, sparse +from scipy import spatial, sparse, interpolate from dataclasses import dataclass, fields, asdict, make_dataclass try: from pykdtree.kdtree import KDTree as pyKDTree except: pyKDTree = spatial.cKDTree +from copy import copy +import json + from meshparty import skeleton_io from collections.abc import Iterable from .skeleton_utils import resample_path @@ -710,8 +713,94 @@ def reroot(self, new_root): @property def distance_to_root(self): + """an an array of distances to root for each skeleton vertex "Distance to root (even if root is not in the mask)" + Returns + ------- + np.array + N length array with the distance to the root node along the skeleton. + """ + return self._rooted.distance_to_root[self.node_mask] + + def resample(self, spacing, kind='nearest'): + """ resample the skeleton to a new spacing and filter it to only include components connected to root + + Parameters + ---------- + spacing : [float] + desired edge spacing in units of vertices + + Returns + ------- + skeleton.Skeleton + a resampled skeleton, which has the parts which are not connected to root removed + resample_map: + a N long array with as many entries as vertices in the original skeleton. + the entry reflects which new skeleton vertex that vertex should be mapped to + """ + cpaths= self.cover_paths + d_to_root = self.distance_to_root + path_counter=0 + branch_d = {} + vert_list= [] + edge_list = [] + + resample_map = -np.ones(len(self.vertices), dtype=np.int32) + + for path in cpaths: + if ~np.isinf(d_to_root[path[-1]]): + # use the distance from root to parameterize the path + input_d = d_to_root[path] + # the desired distances from root are evenly spaced according to spacing + des_d = np.arange(np.min(input_d),np.max(input_d), spacing) + # setup an interpolation function based upon distance to root as input and xyz as output + fi = interpolate.interp1d(input_d, self.vertices[path,:], kind=kind, axis=0) + # use the function to interpolate the new values + new_verts = fi(des_d) + + # find the index of the old branch points in the new path + is_branch = np.isin(np.array(path), self.branch_points) + path_branch = path[is_branch] + path_branch_verts = self.vertices[path_branch, :] + tree = pyKDTree(new_verts) + map_ds, new_branch_on_path = tree.query(path_branch_verts) + new_branch_on_path+=path_counter + # create a temporary dictionary with this path's branch points + new_branch_d = {pb: nw for pb, nw in zip(path_branch, new_branch_on_path)} + # update the overall mapping dictionary + branch_d.update(new_branch_d) + + # map the entire path to the new vertices by euc distance + map_ds, path_map = tree.query(self.vertices[path,:]) + # update the mapping + resample_map[path]=path_map+path_counter + + # new edges just march down path from last vertex to first + new_edges = np.vstack([ np.arange(len(new_verts)-1,0,-1), np.arange(len(new_verts)-2,-1,-1)]).T + path_counter + # need to construct the last edge since it wasn't in the original + # find that last edge whose start point was the first vertex in the path + last_edge=self.edges[self.edges[:,0]==path[-1],:] + # for the first path there won't be an edge as the first vertex is root + if len(last_edge)==1: + # if we do have one, then we want to add an edge that is the from the first vertex + # in the path, to the new vertex (mapped through branch_d) of the edge we found + # in the original edge list + new_edges = np.vstack([new_edges, [path_counter, branch_d[last_edge[0,1]] ]]) + # collect the edges and vertices in a list + edge_list.append(new_edges) + vert_list.append(new_verts) + # increment the counter to keep track of how many vertices we have + path_counter += len(new_verts) + + # concatenate the results together + new_verts = np.vstack(vert_list) + new_edges = np.vstack(edge_list) + + # create a new skeleton + # TODO add options to update mesh mapping and vertex properties + return Skeleton(new_verts, new_edges, + root=branch_d[self.root]), resample_map def path_to_root(self, v_ind): "Path stops if it leaves masked region" diff --git a/test/skeleton_test.py b/test/skeleton_test.py index 82bb9d0..8a326f6 100644 --- a/test/skeleton_test.py +++ b/test/skeleton_test.py @@ -123,6 +123,11 @@ def test_downstream_nodes(full_cell_skeleton): assert len(sk.downstream_nodes(sk.root)) == sk.n_vertices assert len(sk.downstream_nodes(300)) == 135 +def test_resample(full_cell_skeleton): + sk = full_cell_skeleton + skn, resamp_map = sk.resample(1000) + assert(skn.n_vertices==4476) + assert(len(resamp_map)==sk.n_vertices) def test_skeleton_quality(full_cell_skeleton, full_cell_mesh, mesh_link_edges): sk = full_cell_skeleton @@ -132,3 +137,4 @@ def test_skeleton_quality(full_cell_skeleton, full_cell_mesh, mesh_link_edges): skeleton_quality.skeleton_path_quality(sk, mesh, return_path_info=True) assert len(pscore) == len(sk.cover_paths) assert np.isclose(pscore.sum(), -151.377, 0.001) +