Skip to content

Commit

Permalink
Added mode-based jit logic to avoid overflow-related warnings in rng …
Browse files Browse the repository at this point in the history
…logic
  • Loading branch information
braxtoncuneo committed Aug 29, 2023
1 parent 58c7d12 commit 24e9c66
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 43 deletions.
65 changes: 22 additions & 43 deletions mcdc/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,55 +83,34 @@ def sample_discrete(group, P):
# =============================================================================


@njit(numba.uint64(numba.uint64))
def bot_64(a):
half_mask = 0xFFFFFFFF
return a & half_mask

@njit(numba.uint64(numba.uint64, numba.uint64))
def wrapping_mul(a,b):
return a * b

@njit(numba.uint64(numba.uint64))
def top_64(a):
half_mask = 0xFFFFFFFF
return (a >> 32) & half_mask
@njit(numba.uint64(numba.uint64, numba.uint64))
def wrapping_add(a, b):
return a + b

def wrapping_mul_python(a,b):
a = numba.uint64(a)
b = numba.uint64(b)
with np.errstate(all='ignore'):
return a * b

@njit(numba.uint64(numba.uint64, numba.uint64))
def wrapping_mul_32_bit(a, b):
a_lo = bot_64(a)
a_hi = top_64(a)
b_lo = bot_64(b)
b_hi = top_64(b)
x = a_lo * b_lo
x_lo = bot_64(x)
x_hi = top_64(x)
y = bot_64(a_lo * b_hi)
z = bot_64(a_hi * b_lo)
top = bot_64(x_hi + y + z)
bot = x_lo
result = (top << 32) | bot
return result
def wrapping_add_python(a, b):
a = numba.uint64(a)
b = numba.uint64(b)
with np.errstate(all='ignore'):
return a + b


@njit(numba.uint64(numba.uint64, numba.uint64))
def wrapping_mul(a, b):
mask = numba.uint64(0xFFFFFFFFFFFFFFFF)
return (a * b) & mask
def adapt_rng(object_mode=False):
global wrapping_add, wrapping_mul
if object_mode:
wrapping_add = wrapping_add_python
wrapping_mul = wrapping_mul_python


@njit(numba.uint64(numba.uint64, numba.uint64))
def wrapping_add_32_bit(a, b):
a_lo = bot_64(a)
a_hi = top_64(a)
b_lo = bot_64(b)
b_hi = top_64(b)
x = a_lo + b_lo
x_lo = bot_64(x)
x_hi = top_64(x)
y = bot_64(a_hi + b_hi)
top = bot_64(x_hi + y)
bot = x_lo
result = (top << 32) | bot
return result


@njit(numba.uint64(numba.uint64, numba.uint64))
Expand Down Expand Up @@ -187,7 +166,7 @@ def rng_skip_ahead_(n, mcdc):

@njit(numba.uint64(numba.uint64))
def rng_(seed):
return (RNG_G * seed + RNG_C) & RNG_MOD_MASK
return wrapping_add(wrapping_mul(RNG_G,seed), RNG_C) & RNG_MOD_MASK


@njit
Expand Down
1 change: 1 addition & 0 deletions mcdc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def prepare():
type_.make_type_tally(N_tally_scores, input_deck.tally)
type_.make_type_technique(N_particle, G, input_deck.technique)
type_.make_type_global(input_deck)
kernel.adapt_rng(nb.config.DISABLE_JIT)

# =========================================================================
# Make the global variable container
Expand Down

0 comments on commit 24e9c66

Please sign in to comment.