diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 95605175ff13..ed68ecb9a023 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -7168,3 +7168,14 @@ def aliasing_kernel(buffer, buffer2): buffer = torch.zeros(1, device=device) aliasing_kernel[(1, )](buffer, buffer) assert buffer[0] == 1 + + +@pytest.mark.interpreter +@pytest.mark.parametrize("dtype", map(tl.dtype, tl.dtype.SINT_TYPES + tl.dtype.UINT_TYPES + tl.dtype.STANDARD_FP_TYPES)) +def test_dtypes(device, dtype): + + @triton.jit + def dtype_kernel(dtype): + tensor = tl.zeros((1, ), dtype) + + dtype_kernel[(1, )](dtype) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 8ce79df0457f..8f099b8e0f9f 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -306,6 +306,9 @@ def get_float_ty(self): def get_double_ty(self): return tl.float64 + def get_int1_ty(self): + return tl.int1 + def get_int8_ty(self): return tl.int8