diff --git a/mcdc/kernel.py b/mcdc/kernel.py
index a317ad5a..a3c0e5ea 100644
--- a/mcdc/kernel.py
+++ b/mcdc/kernel.py
@@ -83,55 +83,34 @@ def sample_discrete(group, P):
 # =============================================================================
 
 
-@njit(numba.uint64(numba.uint64))
-def bot_64(a):
-    half_mask = 0xFFFFFFFF
-    return a & half_mask
-
+@njit(numba.uint64(numba.uint64, numba.uint64))
+def wrapping_mul(a,b):
+    return a * b
 
-@njit(numba.uint64(numba.uint64))
-def top_64(a):
-    half_mask = 0xFFFFFFFF
-    return (a >> 32) & half_mask
+@njit(numba.uint64(numba.uint64, numba.uint64))
+def wrapping_add(a, b):
+    return a + b
 
+def wrapping_mul_python(a,b):
+    a = numba.uint64(a)
+    b = numba.uint64(b)
+    with np.errstate(all='ignore'):
+        return a * b
 
-@njit(numba.uint64(numba.uint64, numba.uint64))
-def wrapping_mul_32_bit(a, b):
-    a_lo = bot_64(a)
-    a_hi = top_64(a)
-    b_lo = bot_64(b)
-    b_hi = top_64(b)
-    x = a_lo * b_lo
-    x_lo = bot_64(x)
-    x_hi = top_64(x)
-    y = bot_64(a_lo * b_hi)
-    z = bot_64(a_hi * b_lo)
-    top = bot_64(x_hi + y + z)
-    bot = x_lo
-    result = (top << 32) | bot
-    return result
+def wrapping_add_python(a, b):
+    a = numba.uint64(a)
+    b = numba.uint64(b)
+    with np.errstate(all='ignore'):
+        return a + b
 
 
-@njit(numba.uint64(numba.uint64, numba.uint64))
-def wrapping_mul(a, b):
-    mask = numba.uint64(0xFFFFFFFFFFFFFFFF)
-    return (a * b) & mask
+def adapt_rng(object_mode=False):
+    global wrapping_add, wrapping_mul
+    if object_mode:
+        wrapping_add = wrapping_add_python
+        wrapping_mul = wrapping_mul_python
 
 
-@njit(numba.uint64(numba.uint64, numba.uint64))
-def wrapping_add_32_bit(a, b):
-    a_lo = bot_64(a)
-    a_hi = top_64(a)
-    b_lo = bot_64(b)
-    b_hi = top_64(b)
-    x = a_lo + b_lo
-    x_lo = bot_64(x)
-    x_hi = top_64(x)
-    y = bot_64(a_hi + b_hi)
-    top = bot_64(x_hi + y)
-    bot = x_lo
-    result = (top << 32) | bot
-    return result
 
 
 @njit(numba.uint64(numba.uint64, numba.uint64))
@@ -187,7 +166,7 @@ def rng_skip_ahead_(n, mcdc):
 
 @njit(numba.uint64(numba.uint64))
 def rng_(seed):
-    return (RNG_G * seed + RNG_C) & RNG_MOD_MASK
+    return wrapping_add(wrapping_mul(RNG_G,seed), RNG_C) & RNG_MOD_MASK
 
 
 @njit
diff --git a/mcdc/main.py b/mcdc/main.py
index 14a0e795..19106875 100644
--- a/mcdc/main.py
+++ b/mcdc/main.py
@@ -168,6 +168,7 @@ def prepare():
     type_.make_type_tally(N_tally_scores, input_deck.tally)
     type_.make_type_technique(N_particle, G, input_deck.technique)
     type_.make_type_global(input_deck)
+    kernel.adapt_rng(nb.config.DISABLE_JIT)
 
     # =========================================================================
     # Make the global variable container