Skip to content

Commit

Permalink
Added docstrings for DART functions
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed May 31, 2024
1 parent 4160e0a commit 6794329
Showing 1 changed file with 63 additions and 17 deletions.
80 changes: 63 additions & 17 deletions tomotools/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

def run_alg(sino, iters, sino_id, alg_id, rec_id):
"""
Run reconstruction algorithm.
Run FBP or SIRT reconstruction algorithm.
Args
----------
sion : NumPy array
sino : NumPy array
Sinogram of shape (nangles, ny)
iters : int
Number of iterations for the SIRT reconstruction
Expand All @@ -48,16 +48,52 @@ def run_alg(sino, iters, sino_id, alg_id, rec_id):
return astra.data2d.get(rec_id)


def run_dart(sinogram, iters, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id,
thresholds, gray_levels, thickness, ny):
astra.data2d.store(sino_id, sinogram)
def run_dart(sino, iters, dart_iters, p,
alg_id, proj_id, mask_id, rec_id, sino_id,
thresholds, gray_levels):
"""
Run FBP or SIRT reconstruction algorithm.
Args
----------
sino : NumPy array
Sinogram of shape (nangles, ny)
iters : int
Number of iterations for the SART reconstruction
dart_iters : int
Number of iterations for the DART reconstruction
p : float
Probability for free pixel determination
alg_id : int
ASTRA algorithm identity
proj_id : int
ASTRA projector identity
mask_id : boolean
ASTRA mask identity
rec_id : boolean
ASTRA reconstruction identity
sino_id : int
ASTRA sinogram identity
thresholds : list or NumPy array
Thresholds for DART reconstruction
gray_levels : list or NumPy array
Gray levels for DART reconstruction
Returns
----------
Numpy array
Reconstruction of input sinogram
"""
thickness, ny = astra.data2d.get(rec_id).shape
astra.data2d.store(sino_id, sino)
astra.data2d.store(rec_id, np.zeros([thickness, ny]))
astra.data2d.store(mask_id, np.ones([thickness, ny]))
astra.algorithm.run(alg, iters)
astra.algorithm.run(alg_id, iters)
curr_rec = astra.data2d.get(rec_id)
dart_rec = copy.deepcopy(curr_rec)
for j in range(dart_iterations):
for j in range(dart_iters):
segmented = dart_segment(dart_rec, thresholds, gray_levels)
boundary = get_dart_boundaries(segmented)

Expand All @@ -76,17 +112,17 @@ def run_dart(sinogram, iters, dart_iterations, p,
fixed_rec = copy.deepcopy(dart_rec)
fixed_rec[free_idx[0], free_idx[1]] = 0
_, fixed_sino = astra.creators.create_sino(fixed_rec, proj_id)
free_sino = sinogram - fixed_sino
free_sino = sino - fixed_sino

# Run SART reconstruction on free sinogram with free pixel mask
astra.data2d.store(rec_id, dart_rec)
astra.data2d.store(mask_id, free)
astra.data2d.store(sino_id, free_sino)
astra.algorithm.run(alg, iters)
astra.algorithm.run(alg_id, iters)
dart_rec = astra.data2d.get(rec_id)

# Smooth reconstruction
if j < dart_iterations - 1:
if j < dart_iters - 1:
smooth = gaussian_filter(dart_rec, sigma=1)
curr_rec[free_idx[0], free_idx[1]] = smooth[free_idx[0], free_idx[1]]
else:
Expand Down Expand Up @@ -117,6 +153,18 @@ def run(stack, method, niterations=20, constrain=None, thresh=0, cuda=None, thic
cuda : boolean
If True, use the CUDA-accelerated Astra algorithms. Otherwise,
use the CPU-based algorithms
thickness : int
Limit for the height of the reconstruction
ncores : int
Number of cores to use for multithreaded CPU-based reconstructions
filter : str
Filter to use for filtered backprojection
gray_levels : list or NumPy array
Gray levels for DART reconstruction
dart_iterations : int
Number of DART iterations
p : float
Probability for setting free pixels in DART reconstruction
Returns
----------
Expand Down Expand Up @@ -200,8 +248,7 @@ def run(stack, method, niterations=20, constrain=None, thresh=0, cuda=None, thic
astra.data2d.store(rec_id, np.zeros([thickness, ny]))
astra.data2d.store(mask_id, np.ones([thickness, ny]))
rec[i, :, :] = run_dart(sinogram, niterations, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id, thresholds,
gray_levels, thickness, ny)
alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
else:
if ncores is None:
ncores = min(nx, int(0.9 * mp.cpu_count()))
Expand Down Expand Up @@ -256,16 +303,15 @@ def run(stack, method, niterations=20, constrain=None, thresh=0, cuda=None, thic
if ncores == 1:
for i in tqdm.tqdm(range(0, nx)):
rec[i] = run_dart(stack.data[:, :, i], niterations, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id, thresholds,
gray_levels, thickness, ny)
alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
else:
logger.info("Using %s CPU cores to reconstruct %s slices" % (ncores, nx))
with mp.Pool(ncores) as pool:
for i, result in enumerate(
pool.starmap(run_dart,
[(stack.data[:, :, i], niterations, dart_iterations, p,
alg, proj_id, mask_id, rec_id, sino_id, thresholds,
gray_levels, thickness, ny) for i in range(0, nx)],)):
alg, proj_id, mask_id, rec_id, sino_id, thresholds, gray_levels)
for i in range(0, nx)],)):
rec[i] = result
astra.clear()
return rec
Expand Down

0 comments on commit 6794329

Please sign in to comment.