Skip to content

Commit

Permalink
Added tests for DART
Browse files Browse the repository at this point in the history
  • Loading branch information
AndrewHerzing committed May 31, 2024
1 parent 1799db2 commit 4160e0a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tomotools/tests/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ def test_recon_sirt_gpu(self):
assert type(rec) is tomotools.base.RecStack
assert rec.data.shape[2] == slices.data.shape[1]

def test_recon_dart_gpu(self):
stack = ds.get_needle_data(True)
slices = stack.isig[120:121, :].deepcopy()
gray_levels = [0., slices.data.max() / 2, slices.data.max()]
rec = slices.reconstruct('DART',
constrain=True,
iterations=2,
thresh=0,
cuda=True,
gray_levels=gray_levels,
dart_iterations=1)
assert type(stack) is tomotools.base.TomoStack
assert type(rec) is tomotools.base.RecStack
assert rec.data.shape[2] == slices.data.shape[1]


@pytest.mark.skipif(not astra.use_cuda(), reason="CUDA not detected")
class TestAstraSIRTGPU:
Expand Down Expand Up @@ -71,6 +86,15 @@ def test_run_sirt_cuda(self):
assert rec.data.shape[0] == slices.data.shape[2]
assert type(rec) is numpy.ndarray

def test_run_dart_cuda(self):
stack = ds.get_needle_data(True)
slices = stack.isig[120:121, :].deepcopy()
gray_levels = [0., slices.data.max() / 2, slices.data.max()]
rec = recon.run(slices, 'DART', niterations=2, cuda=False, gray_levels=gray_levels, dart_iterations=1)
assert rec.data.shape == (1, slices.data.shape[1], slices.data.shape[1])
assert rec.data.shape[0] == slices.data.shape[2]
assert type(rec) is numpy.ndarray


@pytest.mark.skipif(not astra.use_cuda(), reason="CUDA not detected")
class TestStackRegisterCUDA:
Expand Down
27 changes: 27 additions & 0 deletions tomotools/tests/test_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ def test_recon_sirt_cpu(self):
assert type(rec) is tomotools.base.RecStack
assert rec.data.shape[2] == slices.data.shape[1]

def test_recon_dart_cpu(self):
stack = ds.get_needle_data(True)
slices = stack.isig[120:121, :].deepcopy()
gray_levels = [0., slices.data.max() / 2, slices.data.max()]
rec = slices.reconstruct('DART', iterations=2, cuda=False, gray_levels=gray_levels, dart_iterations=1, ncores=1)
assert type(stack) is tomotools.base.TomoStack
assert type(rec) is tomotools.base.RecStack
assert rec.data.shape[2] == slices.data.shape[1]

def test_recon_dart_cpu_multicore(self):
stack = ds.get_needle_data(True)
slices = stack.isig[120:122, :].deepcopy()
gray_levels = [0., slices.data.max() / 2, slices.data.max()]
rec = slices.reconstruct('DART', iterations=2, cuda=False, gray_levels=gray_levels, dart_iterations=1, ncores=1)
assert type(stack) is tomotools.base.TomoStack
assert type(rec) is tomotools.base.RecStack
assert rec.data.shape[2] == slices.data.shape[1]


class TestReconRun:
def test_run_fbp_no_cuda(self):
Expand All @@ -75,6 +93,15 @@ def test_run_sirt_no_cuda(self):
assert rec.data.shape[0] == slices.data.shape[2]
assert type(rec) is numpy.ndarray

def test_run_dart_no_cuda(self):
stack = ds.get_needle_data(True)
slices = stack.isig[120:121, :].deepcopy()
gray_levels = [0., slices.data.max() / 2, slices.data.max()]
rec = recon.run(slices, 'DART', niterations=2, cuda=False, gray_levels=gray_levels, dart_iterations=1)
assert rec.data.shape == (1, slices.data.shape[1], slices.data.shape[1])
assert rec.data.shape[0] == slices.data.shape[2]
assert type(rec) is numpy.ndarray


class TestAstraSIRTError:
def test_astra_sirt_error_cpu(self):
Expand Down

0 comments on commit 4160e0a

Please sign in to comment.