diff --git a/mcdc/algorithm.py b/mcdc/algorithm.py index bba90b5b..c35a64e0 100644 --- a/mcdc/algorithm.py +++ b/mcdc/algorithm.py @@ -2,10 +2,10 @@ @njit -def binary_search(val, grid, length=0): +def binary_search_with_length(val, grid, length): """ - Binary search that returns the bin index of the value `val` given grid `grid`. - Only search up to `length`-th element if `length` is given. + Binary search that returns the bin index of the value `val` given grid `grid` + Only search up to `length`-th element Some special cases: val < min(grid) --> -1 @@ -26,3 +26,12 @@ def binary_search(val, grid, length=0): else: right = mid - 1 return int(right) + + +@njit +def binary_search(val, grid): + """ + Binary search with full length of the given grid. + See binary_search_with _length + """ + return binary_search_with_length(val, grid, 0) diff --git a/mcdc/kernel.py b/mcdc/kernel.py index abddd69a..22222a8f 100644 --- a/mcdc/kernel.py +++ b/mcdc/kernel.py @@ -16,7 +16,7 @@ import mcdc.type_ as type_ from mcdc.adapt import toggle, for_cpu, for_gpu -from mcdc.algorithm import binary_search +from mcdc.algorithm import binary_search, binary_search_with_length from mcdc.constant import * from mcdc.loop import loop_source from mcdc.print_ import print_error, print_msg @@ -3244,7 +3244,7 @@ def get_microXS(type_, nuclide, E): @njit def get_XS(data, E, E_grid, NE): # Search XS energy bin index - idx = binary_search(E, E_grid, NE) + idx = binary_search_with_length(E, E_grid, NE) # Extrapolate if E is outside the given data if idx == -1: @@ -3314,7 +3314,7 @@ def sample_Eout(P_new, E_grid, NE, chi): xi = rng(P_new) # Determine bin index - idx = binary_search(xi, chi, NE) + idx = binary_search_with_length(xi, chi, NE) # Linear interpolation E1 = E_grid[idx]