Skip to content

Commit

Permalink
Fixed wavefield kernels and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daurer committed Jan 26, 2025
1 parent ec5bd9e commit d8ac247
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 90 deletions.
1 change: 1 addition & 0 deletions ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ extern "C"

probe += pa[0] * E * F + pa[1] * F + pa[2];
obj += oa[0] * H * I + oa[1] * I + oa[2];
obj_fln += oa[0] * H * I + oa[1] * I + oa[2];

assert(oa[0] * H * I + oa[1] * I + oa[2] + (B - 1) * I + C - 1 < G * H * I);

Expand Down
1 change: 1 addition & 0 deletions ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ extern "C"
const int* ea = addr + 6 + bid * addr_stride;

probe += pa[0] * E * F + pa[1] * F + pa[2];
probe_fln += pa[0] * E * F + pa[1] * F + pa[2];
obj += oa[0] * H * I + oa[1] * I + oa[2];

assert(oa[0] * H * I + oa[1] * I + oa[2] + (B - 1) * I + C - 1 < G * H * I);
Expand Down
2 changes: 1 addition & 1 deletion templates/misc/moonflower_ML_cupy_fft_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# attach a reconstrucion engine
p.engines = u.Param()
p.engines.engine00 = u.Param()
p.engines.engine00.name = 'ML_cupy'
p.engines.engine00.name = 'ML_serial'
p.engines.engine00.numiter = 300
p.engines.engine00.numiter_contiguous = 5
p.engines.engine00.reg_del2 = True # Whether to use a Gaussian prior (smoothing) regularizer
Expand Down
63 changes: 32 additions & 31 deletions test/accelerate_tests/base_tests/po_update_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_ob_update_ML(self):
def test_pr_update_ML_wavefield(self):
# setup
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
probe_wavefield = np.empty_like(probe, dtype=FLOAT_TYPE)
probe_wavefield = np.zeros_like(probe, dtype=FLOAT_TYPE)

# test
POUK = PoUpdateKernel()
Expand All @@ -269,26 +269,27 @@ def test_pr_update_ML_wavefield(self):
err_msg="The probe has not been updated as expected")

# assert
expected_probe_wavefield = np.array([[[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.]],

[[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375]]],
dtype=FLOAT_TYPE)
expected_probe_wavefield = np.array(
[[[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.]],

[[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.]]],
dtype=FLOAT_TYPE)
np.testing.assert_array_equal(probe_wavefield, expected_probe_wavefield,
err_msg="The probe wavefield has not been updated as expected")

def test_ob_update_ML_wavefield(self):
# setup
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
object_array_wavefield = np.empty_like(object_array, dtype=FLOAT_TYPE)

object_array_wavefield = np.zeros_like(object_array, dtype=FLOAT_TYPE)
# test
POUK = PoUpdateKernel()
POUK.allocate() # this doesn't do anything, but is the call pattern.
Expand Down Expand Up @@ -317,22 +318,22 @@ def test_ob_update_ML_wavefield(self):

# assert
expected_object_array_wavefield = np.array(
[[[10., 22.4375, 20., 22.4375, 20., 12.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[12.4375, 20., 22.4375, 20., 22.4375, 10., 2.4375],
[0., 2.4375, 0., 2.4375, 0., 2.4375, 0.]],

[[12.4375, 20., 22.4375, 20., 22.4375, 10., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[10., 22.4375, 20., 22.4375, 20., 12.4375, 0.],
[2.4375, 0., 2.4375, 0., 2.4375, 0., 2.4375]]],
dtype=FLOAT_TYPE)
[[[10., 20., 20., 20., 20., 10., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[10., 20., 20., 20., 20., 10., 0.],
[ 0., 0., 0., 0., 0., 0., 0.]],

[[10., 20., 20., 20., 20., 10., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[10., 20., 20., 20., 20., 10., 0.],
[ 0., 0., 0., 0., 0., 0., 0.]]],
dtype=FLOAT_TYPE)
np.testing.assert_array_equal(object_array_wavefield, expected_object_array_wavefield,
err_msg="The object array wavefield has not been updated as expected")

Expand Down
59 changes: 30 additions & 29 deletions test/accelerate_tests/cuda_cupy_tests/po_update_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def pr_update_ML_wavefield_tester(self, atomics=False):
setup
'''
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
probe_wavefield = cp.asarray(np.empty_like(probe, dtype=FLOAT_TYPE))
probe_wavefield = cp.asarray(np.zeros_like(probe, dtype=FLOAT_TYPE))
'''
test
'''
Expand Down Expand Up @@ -682,18 +682,19 @@ def pr_update_ML_wavefield_tester(self, atomics=False):
np.testing.assert_array_equal(probe.get(), expected_probe,
err_msg="The probe has not been updated as expected")

expected_probe_wavefield = np.array([[[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.]],

[[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375]]],
dtype=FLOAT_TYPE)
expected_probe_wavefield = np.array(
[[[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.]],

[[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.]]],
dtype=FLOAT_TYPE)
np.testing.assert_array_equal(probe_wavefield.get(), expected_probe_wavefield,
err_msg="The probe wavefield has not been updated as expected")

Expand All @@ -708,7 +709,7 @@ def ob_update_ML_wavefield_tester(self, atomics=True):
setup
'''
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
object_array_wavefield = cp.asarray(np.empty_like(object_array, dtype=FLOAT_TYPE))
object_array_wavefield = cp.asarray(np.zeros_like(object_array, dtype=FLOAT_TYPE))
'''
test
'''
Expand Down Expand Up @@ -743,21 +744,21 @@ def ob_update_ML_wavefield_tester(self, atomics=True):
err_msg="The object array has not been updated as expected")

expected_object_array_wavefield = np.array(
[[[10., 22.4375, 20., 22.4375, 20., 12.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[12.4375, 20., 22.4375, 20., 22.4375, 10., 2.4375],
[0., 2.4375, 0., 2.4375, 0., 2.4375, 0.]],

[[12.4375, 20., 22.4375, 20., 22.4375, 10., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[10., 22.4375, 20., 22.4375, 20., 12.4375, 0.],
[2.4375, 0., 2.4375, 0., 2.4375, 0., 2.4375]]],
[[[10., 20., 20., 20., 20., 10., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[10., 20., 20., 20., 20., 10., 0.],
[ 0., 0., 0., 0., 0., 0., 0.]],

[[10., 20., 20., 20., 20., 10., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[10., 20., 20., 20., 20., 10., 0.],
[ 0., 0., 0., 0., 0., 0., 0.]]],
dtype=FLOAT_TYPE)
np.testing.assert_array_equal(object_array_wavefield.get(), expected_object_array_wavefield,
err_msg="The object array wavefield has not been updated as expected")
Expand Down
59 changes: 30 additions & 29 deletions test/accelerate_tests/cuda_pycuda_tests/po_update_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def pr_update_ML_wavefield_tester(self, atomics=False):
setup
'''
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
probe_wavefield = gpuarray.to_gpu(np.empty_like(probe, dtype=FLOAT_TYPE))
probe_wavefield = gpuarray.to_gpu(np.zeros_like(probe, dtype=FLOAT_TYPE))
'''
test
'''
Expand Down Expand Up @@ -682,18 +682,19 @@ def pr_update_ML_wavefield_tester(self, atomics=False):
np.testing.assert_array_equal(probe.get(), expected_probe,
err_msg="The probe has not been updated as expected")

expected_probe_wavefield = np.array([[[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.]],

[[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375],
[136., 138.4375, 136., 138.4375, 136.],
[138.4375, 136., 138.4375, 136., 138.4375]]],
dtype=FLOAT_TYPE)
expected_probe_wavefield = np.array(
[[[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.]],

[[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.],
[136., 136., 136., 136., 136.]]],
dtype=FLOAT_TYPE)
np.testing.assert_array_equal(probe_wavefield.get(), expected_probe_wavefield,
err_msg="The probe wavefield has not been updated as expected")

Expand All @@ -708,7 +709,7 @@ def ob_update_ML_wavefield_tester(self, atomics=True):
setup
'''
addr, object_array, object_array_denominator, probe, exit_wave, probe_denominator = self.prepare_arrays()
object_array_wavefield = gpuarray.to_gpu(np.empty_like(object_array, dtype=FLOAT_TYPE))
object_array_wavefield = gpuarray.to_gpu(np.zeros_like(object_array, dtype=FLOAT_TYPE))
'''
test
'''
Expand Down Expand Up @@ -743,21 +744,21 @@ def ob_update_ML_wavefield_tester(self, atomics=True):
err_msg="The object array has not been updated as expected")

expected_object_array_wavefield = np.array(
[[[10., 22.4375, 20., 22.4375, 20., 12.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[12.4375, 20., 22.4375, 20., 22.4375, 10., 2.4375],
[0., 2.4375, 0., 2.4375, 0., 2.4375, 0.]],

[[12.4375, 20., 22.4375, 20., 22.4375, 10., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[20., 42.4375, 40., 42.4375, 40., 22.4375, 0.],
[22.4375, 40., 42.4375, 40., 42.4375, 20., 2.4375],
[10., 22.4375, 20., 22.4375, 20., 12.4375, 0.],
[2.4375, 0., 2.4375, 0., 2.4375, 0., 2.4375]]],
[[[10., 20., 20., 20., 20., 10., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[10., 20., 20., 20., 20., 10., 0.],
[ 0., 0., 0., 0., 0., 0., 0.]],

[[10., 20., 20., 20., 20., 10., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[20., 40., 40., 40., 40., 20., 0.],
[10., 20., 20., 20., 20., 10., 0.],
[ 0., 0., 0., 0., 0., 0., 0.]]],
dtype=FLOAT_TYPE)
np.testing.assert_array_equal(object_array_wavefield.get(), expected_object_array_wavefield,
err_msg="The object array wavefield has not been updated as expected")
Expand Down

0 comments on commit d8ac247

Please sign in to comment.