Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] Taichi Runtime Python language binding #8117

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 1, 2023

Verified

This commit was signed with the committer’s verified signature.
JLLeitschuh Jonathan Leitschuh
commit f7150bfa40c0d4ee3e532441f6a05b91d2861f39
119 changes: 80 additions & 39 deletions c_api/python/taichi_runtime/impl.py
Original file line number Diff line number Diff line change
@@ -82,10 +82,10 @@ def __del__(self):
self.destroy(quiet=True)

@staticmethod
def create(arch: Arch | List[Arch], *, device_index: int = 0) -> 'Runtime':
def create(arch: Arch | List[Arch], *, device_index: int = 0) -> "Runtime":
if isinstance(arch, Arch):
arch = [arch]
handle= TiRuntime(0)
handle = TiRuntime(0)
for a in arch:
try:
handle = ti_create_runtime(TiArch(a.value), ctypes.c_uint32(device_index))
@@ -112,7 +112,7 @@ def wait(self):
ti_wait(self._handle)
check_last_error()

def copy_memory_device_to_device(self, *, dst: 'Memory', src: 'Memory'):
def copy_memory_device_to_device(self, *, dst: "Memory", src: "Memory"):
dst2 = TiMemorySlice(
memory=dst._handle,
offset=0,
@@ -135,7 +135,9 @@ class MemoryUsage(Enum):


class Memory:
def __init__(self, runtime: Runtime, handle: TiMemory, *, size: int, host_access: bool, should_destroy: bool = True):
def __init__(
self, runtime: Runtime, handle: TiMemory, *, size: int, host_access: bool, should_destroy: bool = True
):
self._runtime = runtime
self._handle = handle
self._size = size
@@ -192,13 +194,18 @@ def read(self, dst: ByteString, *, force: bool = False):
ti_unmap_memory(self._runtime._handle, self._handle)
check_last_error()
elif force:
staging_buffer = Memory.allocate(self._runtime, size=self._size, host_access=True, usage=MemoryUsage.STORAGE)
staging_buffer = Memory.allocate(
self._runtime, size=self._size, host_access=True, usage=MemoryUsage.STORAGE
)
self._runtime.copy_memory_device_to_device(dst=staging_buffer, src=self)
self._runtime.wait()
staging_buffer.read(dst)
del staging_buffer
else:
raise TaichiRuntimeError(Error.NOT_SUPPORTED, "Memory.read() is not supported when `host_access` is False. Use `force=True` to force copying to host.")
raise TaichiRuntimeError(
Error.NOT_SUPPORTED,
"Memory.read() is not supported when `host_access` is False. Use `force=True` to force copying to host.",
)

def write(self, src: ByteString, *, force: bool = False):
assert isinstance(src, ByteString)
@@ -215,13 +222,18 @@ def write(self, src: ByteString, *, force: bool = False):
ti_unmap_memory(self._runtime._handle, self._handle)
check_last_error()
elif force:
staging_buffer = Memory.allocate(self._runtime, size=self._size, host_access=True, usage=MemoryUsage.STORAGE)
staging_buffer = Memory.allocate(
self._runtime, size=self._size, host_access=True, usage=MemoryUsage.STORAGE
)
staging_buffer.write(src)
self._runtime.copy_memory_device_to_device(dst=self, src=staging_buffer)
self._runtime.wait()
del staging_buffer
else:
raise TaichiRuntimeError(Error.NOT_SUPPORTED, "Memory.write() is not supported when `host_access` is False. Use `force=True` to force copying to host.")
raise TaichiRuntimeError(
Error.NOT_SUPPORTED,
"Memory.write() is not supported when `host_access` is False. Use `force=True` to force copying to host.",
)

@staticmethod
def from_bytes(runtime: Runtime, src: ByteString, *, host_access: bool = False):
@@ -266,17 +278,17 @@ class DataType(Enum):
}

_NP_DTYPE_TABLE: Dict[str, DataType] = {
'float16': DataType.F16,
'float32': DataType.F32,
'float64': DataType.F64,
'int8': DataType.I8,
'int16': DataType.I16,
'int32': DataType.I32,
'int64': DataType.I64,
'uint8': DataType.U8,
'uint16': DataType.U16,
'uint32': DataType.U32,
'uint64': DataType.U64,
"float16": DataType.F16,
"float32": DataType.F32,
"float64": DataType.F64,
"int8": DataType.I8,
"int16": DataType.I16,
"int32": DataType.I32,
"int64": DataType.I64,
"uint8": DataType.U8,
"uint16": DataType.U16,
"uint32": DataType.U32,
"uint64": DataType.U64,
}
_DTYPE_NP_TABLE: Dict[DataType, type] = {
DataType.F16: np.float16,
@@ -294,7 +306,9 @@ class DataType(Enum):


class NdArray:
def __init__(self, runtime: Runtime, memory: Memory, *, shape: Tuple[int], elem_shape: Tuple[int], elem_type: DataType):
def __init__(
self, runtime: Runtime, memory: Memory, *, shape: Tuple[int], elem_shape: Tuple[int], elem_type: DataType
):
self._runtime = runtime
self._memory = memory
self._shape = shape
@@ -321,16 +335,35 @@ def elem_type(self) -> DataType:
return self._elem_type

@staticmethod
def allocate(runtime: Runtime, elem_type: DataType, *, shape: Iterable[int], elem_shape: Iterable[int], host_access: bool = False):
size = reduce(lambda x, y: x * y, shape, 1) * reduce(lambda x, y: x * y, elem_shape, 1) * _DTYPE_SIZE_TABLE[elem_type.value]
def allocate(
runtime: Runtime,
elem_type: DataType,
*,
shape: Iterable[int],
elem_shape: Iterable[int],
host_access: bool = False,
):
size = (
reduce(lambda x, y: x * y, shape, 1)
* reduce(lambda x, y: x * y, elem_shape, 1)
* _DTYPE_SIZE_TABLE[elem_type.value]
)
memory = Memory.allocate(runtime, size=size, host_access=host_access, usage=MemoryUsage.STORAGE)
return NdArray(runtime, memory, shape=tuple(shape), elem_shape=tuple(elem_shape), elem_type=elem_type)

def free(self):
self._memory.free()

@staticmethod
def from_numpy(runtime: Runtime, arr: npt.NDArray[Any], *, shape: Optional[Iterable[int]] = None, elem_shape: Optional[Iterable[int]] = None, elem_type: Optional[DataType] = None, host_access=False):
def from_numpy(
runtime: Runtime,
arr: npt.NDArray[Any],
*,
shape: Optional[Iterable[int]] = None,
elem_shape: Optional[Iterable[int]] = None,
elem_type: Optional[DataType] = None,
host_access=False,
):
assert isinstance(arr, np.ndarray)

if elem_type is None:
@@ -344,10 +377,12 @@ def from_numpy(runtime: Runtime, arr: npt.NDArray[Any], *, shape: Optional[Itera
elem_shape2 = tuple(elem_shape)
assert len(elem_shape2) <= len(arr.shape)
for ielem_shape, ishape in enumerate(range(len(arr.shape) - len(elem_shape2), len(arr.shape))):
assert arr.shape[ishape] == elem_shape2[ielem_shape], f"arr.shape[{ishape}] ({arr.shape[ishape]}) != elem_shape2[{ielem_shape}] ({elem_shape2[ielem_shape]})"
assert (
arr.shape[ishape] == elem_shape2[ielem_shape]
), f"arr.shape[{ishape}] ({arr.shape[ishape]}) != elem_shape2[{ielem_shape}] ({elem_shape2[ielem_shape]})"

if shape is None:
shape2 = arr.shape[:len(arr.shape) - len(elem_shape2)]
shape2 = arr.shape[: len(arr.shape) - len(elem_shape2)]
else:
shape2 = tuple(shape)
assert len(shape2) <= len(arr.shape)
@@ -358,7 +393,9 @@ def from_numpy(runtime: Runtime, arr: npt.NDArray[Any], *, shape: Optional[Itera
return NdArray(runtime, memory, shape=shape2, elem_shape=elem_shape2, elem_type=elem_type2)

def to_numpy(self) -> npt.NDArray[Any]:
out = np.frombuffer(self.memory.to_bytes(), dtype=_DTYPE_NP_TABLE[self.elem_type]).reshape(self.shape + self.elem_shape)
out = np.frombuffer(self.memory.to_bytes(), dtype=_DTYPE_NP_TABLE[self.elem_type]).reshape(
self.shape + self.elem_shape
)
return out

def into_numpy(self) -> npt.NDArray[Any]:
@@ -387,31 +424,35 @@ def __init__(self, value: Any, *, ty: Optional[ArgumentType] = None) -> None:
if ty == ArgumentType.I32:
assert isinstance(value, int)
value = TiArgumentValue(
i32 = ctypes.c_int32(value),
i32=ctypes.c_int32(value),
)
elif ty == ArgumentType.F32:
assert isinstance(value, float)
value = TiArgumentValue(
f32 = ctypes.c_float(value),
f32=ctypes.c_float(value),
)
elif ty == ArgumentType.NDARRAY:
assert isinstance(value, NdArray)
shape = TiNdShape(
dim_count = ctypes.c_uint32(len(value._shape)),
dims = (ctypes.c_uint32 * 16)(*[ctypes.c_uint32(x) for x in value._shape] + [0] * (16 - len(value._shape))),
dim_count=ctypes.c_uint32(len(value._shape)),
dims=(ctypes.c_uint32 * 16)(
*[ctypes.c_uint32(x) for x in value._shape] + [0] * (16 - len(value._shape))
),
)
elem_shape = TiNdShape(
dim_count = ctypes.c_uint32(len(value._elem_shape)),
dims = (ctypes.c_uint32 * 16)(*[ctypes.c_uint32(x) for x in value._elem_shape] + [0] * (16 - len(value._elem_shape))),
dim_count=ctypes.c_uint32(len(value._elem_shape)),
dims=(ctypes.c_uint32 * 16)(
*[ctypes.c_uint32(x) for x in value._elem_shape] + [0] * (16 - len(value._elem_shape))
),
)
x = TiNdArray(
memory = value._memory._handle,
shape = shape,
elem_shape = elem_shape,
elem_type = value._elem_type.value,
memory=value._memory._handle,
shape=shape,
elem_shape=elem_shape,
elem_type=value._elem_type.value,
)
value = TiArgumentValue(
ndarray = x,
ndarray=x,
)
else:
raise TaichiRuntimeError(Error.NOT_SUPPORTED, f"ArgumentType.{ty.name} is not supported.")
@@ -421,7 +462,7 @@ def __init__(self, value: Any, *, ty: Optional[ArgumentType] = None) -> None:


class Kernel:
def __init__(self, aot_module: 'AotModule', name: str, handle: TiAotModule):
def __init__(self, aot_module: "AotModule", name: str, handle: TiAotModule):
self._aot_module = aot_module
self._name = name
self._handle = handle
@@ -451,7 +492,7 @@ def load(runtime: Runtime, path: str):
handle = ti_load_aot_module(runtime._handle, _p(path.encode("ascii")))
check_last_error()
return AotModule(runtime, handle)

@staticmethod
def create(runtime: Runtime, tcm: bytes):
handle = ti_create_aot_module(runtime._handle, _p(tcm), ctypes.c_uint64(len(tcm)))
7 changes: 4 additions & 3 deletions c_api/python/taichi_runtime/sys/_lib.py
Original file line number Diff line number Diff line change
@@ -7,9 +7,11 @@
"lib/libtaichi_c_api.dylib",
]


def find_taichi_c_api_in_wheel():
try:
import taichi as ti

for candidate_name in CANDIDATE_NAMES:
try:
taichi_c_api_path = list(ti.__path__)[0] + "/_lib/c_api/" + candidate_name
@@ -21,6 +23,7 @@ def find_taichi_c_api_in_wheel():
pass
return None


def load_taichi_c_api() -> ctypes.CDLL:
import ctypes.util as ctypes_util
from os import environ
@@ -38,9 +41,7 @@ def load_taichi_c_api() -> ctypes.CDLL:
break

if path is None:
raise RuntimeError(
"Cannot find taichi_c_api. Please set TAICHI_C_API_INSTALL_DIR environment variable."
)
raise RuntimeError("Cannot find taichi_c_api. Please set TAICHI_C_API_INSTALL_DIR environment variable.")

print(f"Found taichi_c_api at {path}")
out = ctypes.CDLL(path, ctypes.RTLD_LOCAL)
Loading