Skip to content

Commit

Permalink
Better abs2 in wavefield kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes committed Jan 25, 2025
1 parent 49f529d commit ec5bd9e
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 13 deletions.
4 changes: 2 additions & 2 deletions ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ extern "C" __global__ void ob_update2_ML_wavefield(int pr_sh,
complex<MATH_TYPE> ex_val = ex_g[exidx];
complex<ACC_TYPE> add_val = cpr * ex_val * fac;
ob[idx] += add_val;
MATH_TYPE tmp = abs(pr);
ACC_TYPE add_val2 = tmp * tmp;
complex<MATH_TYPE> abs2_val = cpr * pr;
ACC_TYPE add_val2 = abs2_val.real();
of[idx] += add_val2;
}
}
Expand Down
7 changes: 3 additions & 4 deletions ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,11 @@ extern "C"

atomicAdd(&obj[b * I + c], add_val);

MATH_TYPE tmp = abs(probe_val);
MATH_TYPE add_val_m2 = tmp * tmp;
OUT_TYPE add_val2 = add_val_m2;
complex<MATH_TYPE> abs2_val = conj(probe_val) * probe_val;
OUT_TYPE add_val2 = abs2_val.real();

atomicAdd(&obj_fln[b * I + c], add_val2);
}
}
}
}
}
4 changes: 2 additions & 2 deletions ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ extern "C" __global__ void pr_update2_ML_wavefield(int pr_sh,
complex<MATH_TYPE> ex_val = ex_g[exidx];
complex<ACC_TYPE> add_val = cob * ex_val * fac;
pr[idx] += add_val;
MATH_TYPE tmp = abs(ob);
ACC_TYPE add_val2 = tmp * tmp;
complex<MATH_TYPE> abs2_val = cob * ob;
ACC_TYPE add_val2 = abs2_val.real();
pf[idx] += add_val2;
}
}
Expand Down
9 changes: 4 additions & 5 deletions ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/** pr_update_ML_wavefield.
/** pr_update_ML_wavefield.
*
* Data types:
* - IN_TYPE: the data type for the inputs (float or double)
Expand Down Expand Up @@ -62,12 +62,11 @@ extern "C"

atomicAdd(&probe[b * F + c], add_val);

MATH_TYPE tmp = abs(obj_val);
MATH_TYPE add_val_m2 = tmp * tmp;
OUT_TYPE add_val2 = add_val_m2;
complex<MATH_TYPE> abs2_val = conj(obj_val) * obj_val;
OUT_TYPE add_val2 = abs2_val.real();

atomicAdd(&probe_fln[b * F + c], add_val2);
}
}
}
}
}

0 comments on commit ec5bd9e

Please sign in to comment.