Skip to content

Commit

Permalink
[FRONTEND] remove dead code (#5507)
Browse files Browse the repository at this point in the history
  • Loading branch information
ptillet authored Dec 28, 2024
1 parent 75a4b7b commit 4c91b81
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 40 deletions.
2 changes: 1 addition & 1 deletion python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
src = triton.compiler.compiler.ASTSource(
fn=kernel,
signature={kernel.arg_names[i]: kernel._type_of(kernel._key_of(arg))
signature={kernel.arg_names[i]: triton.runtime.jit.mangle_type(arg)
for i, arg in enumerate(args)},
constexprs={kernel.arg_names[i]: arg
for i, arg in enumerate(args)
Expand Down
4 changes: 2 additions & 2 deletions python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ def _patch_lang(fn):
# TODO: wrap everything in triton tensors
def _implicit_cvt(arg):
if isinstance(arg, int):
ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
dtype = np.int32
if -2**31 <= arg < 2**31:
dtype = np.int32
Expand All @@ -1030,7 +1030,7 @@ def _implicit_cvt(arg):
handle = TensorHandle(np.array([arg], dtype=dtype), ty)
return tl.tensor(handle, ty)
if hasattr(arg, "data_ptr"):
ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
return tl.tensor(handle, ty)
return arg
Expand Down
37 changes: 0 additions & 37 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,43 +449,6 @@ class JITFunction(KernelInterface[T]):
# cache_hook will always be called before compilation and compiled_hook after.
compiled_hook = None

@staticmethod
def _key_of(arg):
if hasattr(arg, "dtype"):
return arg.dtype
elif isinstance(arg, bool):
return "i1"
elif isinstance(arg, int):
if -(2**31) <= arg and arg <= 2**31 - 1:
return "i32"
elif 2**63 <= arg and arg <= 2**64 - 1:
return "u64"
else:
return "i64"
elif isinstance(arg, float):
return "fp32"
elif arg is None:
return None
else:
raise TypeError(f"Unsupported type {type(arg)} for {arg}")

@staticmethod
def _type_of(key, is_const=False):
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return "*i8"
elif isinstance(key, str):
return key

dtype_str = str(key).split(".")[-1]
dtype_str = type_canonicalisation_dict[dtype_str]
const_str = "*k" if is_const else "*"
return const_str + dtype_str

def _make_constants(self, constexpr_key):
constants = dict(zip(self.constexprs, constexpr_key))
return constants

def _call_hook(
self,
key,
Expand Down

0 comments on commit 4c91b81

Please sign in to comment.