From ec5bd9ef8d7dbae7884a7bdcf75dc020d30a19eb Mon Sep 17 00:00:00 2001 From: Jari Date: Sat, 25 Jan 2025 10:17:14 +0000 Subject: [PATCH] Better abs2 in wavefield kernels --- ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu | 4 ++-- ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu | 7 +++---- ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu | 4 ++-- ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu | 9 ++++----- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu b/ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu index ed3ad2d6..649ce4a8 100644 --- a/ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu +++ b/ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu @@ -105,8 +105,8 @@ extern "C" __global__ void ob_update2_ML_wavefield(int pr_sh, complex ex_val = ex_g[exidx]; complex add_val = cpr * ex_val * fac; ob[idx] += add_val; - MATH_TYPE tmp = abs(pr); - ACC_TYPE add_val2 = tmp * tmp; + complex abs2_val = cpr * pr; + ACC_TYPE add_val2 = abs2_val.real(); of[idx] += add_val2; } } diff --git a/ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu b/ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu index 19e76795..80eba3cd 100644 --- a/ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu +++ b/ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu @@ -62,12 +62,11 @@ extern "C" atomicAdd(&obj[b * I + c], add_val); - MATH_TYPE tmp = abs(probe_val); - MATH_TYPE add_val_m2 = tmp * tmp; - OUT_TYPE add_val2 = add_val_m2; + complex abs2_val = conj(probe_val) * probe_val; + OUT_TYPE add_val2 = abs2_val.real(); atomicAdd(&obj_fln[b * I + c], add_val2); } } } -} \ No newline at end of file +} diff --git a/ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu b/ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu index 8271daa0..567d928b 100644 --- a/ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu +++ b/ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu @@ -105,8 +105,8 @@ extern "C" __global__ void pr_update2_ML_wavefield(int pr_sh, complex ex_val = ex_g[exidx]; complex add_val = cob * ex_val * fac; pr[idx] += add_val; - MATH_TYPE tmp = abs(ob); - ACC_TYPE add_val2 = tmp * tmp; + complex abs2_val = cob * ob; + ACC_TYPE add_val2 = abs2_val.real(); pf[idx] += add_val2; } } diff --git a/ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu b/ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu index bfe1c74c..2cd22d32 100644 --- a/ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu +++ b/ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu @@ -1,4 +1,4 @@ -/** pr_update_ML_wavefield. +#/** pr_update_ML_wavefield. * * Data types: * - IN_TYPE: the data type for the inputs (float or double) @@ -62,12 +62,11 @@ extern "C" atomicAdd(&probe[b * F + c], add_val); - MATH_TYPE tmp = abs(obj_val); - MATH_TYPE add_val_m2 = tmp * tmp; - OUT_TYPE add_val2 = add_val_m2; + complex abs2_val = conj(obj_val) * obj_val; + OUT_TYPE add_val2 = abs2_val.real(); atomicAdd(&probe_fln[b * F + c], add_val2); } } } -} \ No newline at end of file +}