Skip to content

Commit

Permalink
small improvements to wavefront_exact:
Browse files Browse the repository at this point in the history
1. Simplify original tree before trying to connect nodes
2. Add progress bar for the connecting step
  • Loading branch information
schlegelp committed Aug 6, 2024
1 parent a703164 commit 9b8a81c
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions skeletor/skeletonize/wave2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import pandas as pd

from tqdm.auto import tqdm
from tqdm.auto import tqdm, trange

from ..utilities import make_trimesh
from .base import Skeleton
Expand Down Expand Up @@ -121,7 +121,9 @@ def _cast_waves(mesh, step_size, origins=None, rad_agg_func=np.mean, progress=Tr
data = []

# Go over each connected component
with tqdm(desc="Skeletonizing", total=len(G.vs), disable=not progress) as pbar:
with tqdm(
desc="Skeletonizing", total=len(G.vs), disable=not progress, leave=False
) as pbar:
for k, cc in enumerate(G.connected_components()):
# Make a subgraph for this connected component
SG = G.subgraph(cc)
Expand Down Expand Up @@ -247,7 +249,11 @@ def _cast_waves(mesh, step_size, origins=None, rad_agg_func=np.mean, progress=Tr
# This maps the step and the vertex ID to the index in the centers array
step_id_map = {(step, id): i for i, (step, id) in enumerate(data[:, 1:])}

for i in range(1, data[:, 1].max()):
tree = G.spanning_tree()

for i in trange(
1, data[:, 1].max(), desc="Connecting", disable=not progress, leave=False
):
is_this_step = data[:, 1] == i # All centers in this step
is_prev_step = data[:, 1] == i - 1 # All centers in the previous step

Expand All @@ -264,8 +270,10 @@ def _cast_waves(mesh, step_size, origins=None, rad_agg_func=np.mean, progress=Tr
) # All track-vertices in the previous step

# Get distances between current and previous track vertices
# Note to self: this step takes up about 60% of the time at the moment
# There should be a way to speed this up.
d = np.array(
G.distances(
tree.distances(
source=this_step_track_verts,
target=prev_step_track_verts,
mode="all",
Expand Down

0 comments on commit 9b8a81c

Please sign in to comment.