Skip to content

Commit

Permalink
Bugfixes in wavefield kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
daurer committed Jan 24, 2025
1 parent ae0c87b commit 49f529d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 6 deletions.
7 changes: 5 additions & 2 deletions ptypy/accelerate/base/engines/ML_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,15 @@ def new_grad(self):
# local references
ob = self.engine.ob.S[oID].data
obg = ob_grad.S[oID].data
obf = ob_fln.S[oID].data
pr = self.engine.pr.S[pID].data
prg = pr_grad.S[pID].data
prf = pr_fln.S[pID].data
I = prep.I

# local references for wavefield precond
if self.engine.p.wavefield_precond:
obf = ob_fln.S[oID].data
prf = pr_fln.S[pID].data

# make propagated exit (to buffer)
AWK.build_aux_no_ex(aux, addr, ob, pr, add=False)

Expand Down
3 changes: 2 additions & 1 deletion ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ extern "C" __global__ void ob_update2_ML_wavefield(int pr_sh,
complex<MATH_TYPE> ex_val = ex_g[exidx];
complex<ACC_TYPE> add_val = cpr * ex_val * fac;
ob[idx] += add_val;
ACC_TYPE add_val2 = cpr * pr;
MATH_TYPE tmp = abs(pr);
ACC_TYPE add_val2 = tmp * tmp;
of[idx] += add_val2;
}
}
Expand Down
3 changes: 2 additions & 1 deletion ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ extern "C"

atomicAdd(&obj[b * I + c], add_val);

MATH_TYPE add_val_m2 = conj(probe_val) * probe_val;
MATH_TYPE tmp = abs(probe_val);
MATH_TYPE add_val_m2 = tmp * tmp;
OUT_TYPE add_val2 = add_val_m2;

atomicAdd(&obj_fln[b * I + c], add_val2);
Expand Down
3 changes: 2 additions & 1 deletion ptypy/accelerate/cuda_common/pr_update2_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ extern "C" __global__ void pr_update2_ML_wavefield(int pr_sh,
complex<MATH_TYPE> ex_val = ex_g[exidx];
complex<ACC_TYPE> add_val = cob * ex_val * fac;
pr[idx] += add_val;
ACC_TYPE add_val2 = cob * ob;
MATH_TYPE tmp = abs(ob);
ACC_TYPE add_val2 = tmp * tmp;
pf[idx] += add_val2;
}
}
Expand Down
3 changes: 2 additions & 1 deletion ptypy/accelerate/cuda_common/pr_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ extern "C"

atomicAdd(&probe[b * F + c], add_val);

MATH_TYPE add_val_m2 = conj(obj_val) * obj_val;
MATH_TYPE tmp = abs(obj_val);
MATH_TYPE add_val_m2 = tmp * tmp;
OUT_TYPE add_val2 = add_val_m2;

atomicAdd(&probe_fln[b * F + c], add_val2);
Expand Down

0 comments on commit 49f529d

Please sign in to comment.