Skip to content

Commit

Permalink
replace njit with cuda.jit
Browse files Browse the repository at this point in the history
  • Loading branch information
ilhamv committed Aug 15, 2024
1 parent ccbfbe8 commit d8a3e00
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions mcdc/code_factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from numba import njit
from numba import cuda, njit

import mcdc.local as local
import mcdc.type_ as type_
Expand Down Expand Up @@ -36,7 +36,7 @@ def local_array(dtype, size, target):
def cpu():
return np.zeros(1, dtype=struct)[0]

@njit
@cuda.jit(device=True)
def gpu():
return cuda.local.array(1, dtype=struct)[0]

Expand All @@ -49,7 +49,7 @@ def local_object(dtype, target):
def cpu():
return np.zeros(1, dtype=dtype)[0]

@njit
@cuda.jit(device=True)
def gpu():
return cuda.local.array(1, dtype=dtype)[0]

Expand Down
2 changes: 1 addition & 1 deletion mcdc/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def prepare():
type_.make_type_global(input_deck)
kernel.adapt_rng(nb.config.DISABLE_JIT)

input_deck.setting['target'] = target
input_deck.setting["target"] = target
code_factory.make_locals(input_deck)

# =========================================================================
Expand Down

0 comments on commit d8a3e00

Please sign in to comment.