Skip to content

Commit

Permalink
revert binary_search for GPU-compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ilhamv committed Aug 19, 2024
1 parent 617dae5 commit 8e1298f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
15 changes: 12 additions & 3 deletions mcdc/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@


@njit
def binary_search(val, grid, length=0):
def binary_search_with_length(val, grid, length):
"""
Binary search that returns the bin index of the value `val` given grid `grid`.
Only search up to `length`-th element if `length` is given.
Binary search that returns the bin index of the value `val` given grid `grid`
Only search up to `length`-th element
Some special cases:
val < min(grid) --> -1
Expand All @@ -26,3 +26,12 @@ def binary_search(val, grid, length=0):
else:
right = mid - 1
return int(right)


@njit
def binary_search(val, grid):
"""
Binary search with full length of the given grid.
See binary_search_with _length
"""
return binary_search_with_length(val, grid, 0)
6 changes: 3 additions & 3 deletions mcdc/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import mcdc.type_ as type_

from mcdc.adapt import toggle, for_cpu, for_gpu
from mcdc.algorithm import binary_search
from mcdc.algorithm import binary_search, binary_search_with_length
from mcdc.constant import *
from mcdc.loop import loop_source
from mcdc.print_ import print_error, print_msg
Expand Down Expand Up @@ -3244,7 +3244,7 @@ def get_microXS(type_, nuclide, E):
@njit
def get_XS(data, E, E_grid, NE):
# Search XS energy bin index
idx = binary_search(E, E_grid, NE)
idx = binary_search_with_length(E, E_grid, NE)

# Extrapolate if E is outside the given data
if idx == -1:
Expand Down Expand Up @@ -3314,7 +3314,7 @@ def sample_Eout(P_new, E_grid, NE, chi):
xi = rng(P_new)

# Determine bin index
idx = binary_search(xi, chi, NE)
idx = binary_search_with_length(xi, chi, NE)

# Linear interpolation
E1 = E_grid[idx]
Expand Down

0 comments on commit 8e1298f

Please sign in to comment.