From 24ced3739a75ee482be898707cb69d569808a9aa Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Fri, 24 Nov 2023 12:01:00 -0800
Subject: [PATCH] test dtypes

---
 .github/workflows/test.yml |   6 +-
 import_from_tinygrad.py    |   2 +-
 mnist.py                   |   2 +-
 sz.py                      |   4 +-
 teenygrad/__init__.py      |   1 +
 teenygrad/helpers.py       |  30 ++++--
 teenygrad/lazy.py          |  34 ++++---
 test/test_dtype.py         | 194 +++++++++++++++++++++++++++++++++++++
 8 files changed, 247 insertions(+), 26 deletions(-)
 create mode 100644 teenygrad/__init__.py
 create mode 100644 test/test_dtype.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 84f9db3..d8f62f9 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -24,5 +24,7 @@ jobs:
       run: PYTHONPATH="." python mnist.py
     - name: Install torch for testing
       run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
-    - name: Test Ops
-      run: PYTHONPATH="." python test/test_ops.py
+    - name: Test Ops and DTypes
+      run: |
+        PYTHONPATH="." python test/test_ops.py
+        PYTHONPATH="." python test/test_dtype.py
diff --git a/import_from_tinygrad.py b/import_from_tinygrad.py
index 925c964..89ee29c 100755
--- a/import_from_tinygrad.py
+++ b/import_from_tinygrad.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 import pathlib
 
-FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py"]
+FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py"]
 
 src = pathlib.Path("../tinygrad/tinygrad")
 dest = pathlib.Path("teenygrad")
diff --git a/mnist.py b/mnist.py
index f37d033..0dbb6ab 100755
--- a/mnist.py
+++ b/mnist.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import numpy as np
-from teenygrad.tensor import Tensor
+from teenygrad import Tensor
 from tqdm import trange
 import gzip, os
 
diff --git a/sz.py b/sz.py
index 4ea8e89..a17c197 100755
--- a/sz.py
+++ b/sz.py
@@ -24,4 +24,6 @@
 
 for dir_name, group in itertools.groupby(sorted([(x[0].rsplit("/", 1)[0], x[1]) for x in table]), key=lambda x:x[0]):
   print(f"{dir_name:30s} : {sum([x[1] for x in group]):6d}")
-print(f"\ntotal line count: {sum([x[1] for x in table])}")
+total_line_count = sum([x[1] for x in table])
+print(f"\ntotal line count: {total_line_count}")
+assert total_line_count < 1000, "TEENYGRAD IS FUCKING TEENY IF YOU GO OVER 1000 LINES IN TEENYGRAD MIGHT AS WELL USE TINYGRAD U FAT FUCK"
diff --git a/teenygrad/__init__.py b/teenygrad/__init__.py
new file mode 100644
index 0000000..26cd867
--- /dev/null
+++ b/teenygrad/__init__.py
@@ -0,0 +1 @@
+from teenygrad.tensor import Tensor # noqa: F401
\ No newline at end of file
diff --git a/teenygrad/helpers.py b/teenygrad/helpers.py
index ab209d1..e11d219 100644
--- a/teenygrad/helpers.py
+++ b/teenygrad/helpers.py
@@ -1,9 +1,10 @@
-from typing import Union, Tuple, Iterator, NamedTuple, Optional, Final, Any
-import os, functools
+from typing import Union, Tuple, Iterator, Optional, Final, Any
+import os, functools, platform
 import numpy as np
 from math import prod # noqa: F401 # pylint:disable=unused-import
 from dataclasses import dataclass
 
+OSX = platform.system() == "Darwin"
 def dedup(x): return list(dict.fromkeys(x)) # retains list order
 def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
@@ -28,16 +29,33 @@ class DType:
   def __repr__(self): return f"dtypes.{self.name}"
 
 class dtypes:
+  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
+  def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
   @staticmethod
-  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
+  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
+  @staticmethod
+  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
   @staticmethod
-  def is_float(x: DType) -> bool: return x in (dtypes.float32, dtypes.float64)
+  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
+  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
+  float16: Final[DType] = DType(9, 2, "half", np.float16)
+  half = float16
   float32: Final[DType] = DType(10, 4, "float", np.float32)
+  float = float32
   float64: Final[DType] = DType(11, 8, "double", np.float64)
+  double = float64
+  int8: Final[DType] = DType(1, 1, "char", np.int8)
+  int16: Final[DType] = DType(3, 2, "short", np.int16)
   int32: Final[DType] = DType(5, 4, "int", np.int32)
   int64: Final[DType] = DType(7, 8, "long", np.int64)
   uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
-  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
+  uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
+  uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
+  uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
+
+  # NOTE: bfloat16 isn't supported in numpy
+  bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
 
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
-ImageDType, IMAGE = None, 0 # junk to remove
+PtrDType, ImageDType, IMAGE = None, None, 0 # junk to remove
diff --git a/teenygrad/lazy.py b/teenygrad/lazy.py
index c4eeb5e..84cc4e6 100644
--- a/teenygrad/lazy.py
+++ b/teenygrad/lazy.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from teenygrad.helpers import DType, dtypes
+from teenygrad.helpers import DType, dtypes, DEBUG
 from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
 import numpy as np
 
@@ -18,6 +18,7 @@ def dtype(self): return dtypes.from_np(self._np.dtype)
   def realized(self): return RawCPUBuffer(self._np)
   @property
   def shape(self): return self._np.shape
+  def __repr__(self): return f"<LB {self.shape} {self.dtype}>"
 
   def schedule(self, seen=None): return []
   def is_unrealized_const(self): return False
@@ -35,27 +36,30 @@ def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
   def contiguous(x): return x
   def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))
 
-  def cast(self, dtype:DType, bitcast:bool=False): return LazyBuffer(self._np.astype(dtype.np))
+  def cast(self, dtype:DType, bitcast:bool=False): return LazyBuffer(self._np.view(dtype.np) if bitcast else self._np.astype(dtype.np))
 
   def e(self, op, *srcs:LazyBuffer):
-    if op == UnaryOps.NEG: return LazyBuffer(-self._np)
-    elif op == UnaryOps.EXP2: return LazyBuffer(np.exp2(self._np))
-    elif op == UnaryOps.LOG2: return LazyBuffer(np.log2(self._np))
-    elif op == UnaryOps.SIN: return LazyBuffer(np.sin(self._np))
-    elif op == UnaryOps.SQRT: return LazyBuffer(np.sqrt(self._np))
-    elif op == BinaryOps.ADD: return LazyBuffer(self._np + srcs[0]._np)
-    elif op == BinaryOps.SUB: return LazyBuffer(self._np - srcs[0]._np)
-    elif op == BinaryOps.MUL: return LazyBuffer(self._np * srcs[0]._np)
-    elif op == BinaryOps.DIV: return LazyBuffer((self._np / srcs[0]._np).astype(max(self.dtype, srcs[0].dtype).np))
-    elif op ==
BinaryOps.MAX: return LazyBuffer(np.maximum(self._np, srcs[0]._np)) - elif op == BinaryOps.CMPLT: return LazyBuffer(self._np < srcs[0]._np) - elif op == TernaryOps.WHERE: return LazyBuffer(np.where(self._np, srcs[0]._np, srcs[1]._np)) + if DEBUG >= 1: print(op, self, srcs) + if op == UnaryOps.NEG: ret = -self._np + elif op == UnaryOps.EXP2: ret = np.exp2(self._np) + elif op == UnaryOps.LOG2: ret = np.log2(self._np) + elif op == UnaryOps.SIN: ret = np.sin(self._np) + elif op == UnaryOps.SQRT: ret = np.sqrt(self._np) + elif op == BinaryOps.ADD: ret = self._np + srcs[0]._np + elif op == BinaryOps.SUB: ret = self._np - srcs[0]._np + elif op == BinaryOps.MUL: ret = self._np * srcs[0]._np + elif op == BinaryOps.DIV: ret = self._np / srcs[0]._np + elif op == BinaryOps.MAX: ret = np.maximum(self._np, srcs[0]._np) + elif op == BinaryOps.CMPLT: ret = self._np < srcs[0]._np + elif op == TernaryOps.WHERE: ret = np.where(self._np, srcs[0]._np, srcs[1]._np) else: raise NotImplementedError(op) + return LazyBuffer(ret.astype(self.dtype.np if len(srcs) == 0 else max(self.dtype, *[x.dtype for x in srcs]).np, copy=False)) def r(self, op, new_shape): + if DEBUG >= 1: print(op, self, new_shape) assert len(self.shape) == len(new_shape), "reduce shapes must have same dimensions" axis = tuple(i for i,(a,b) in enumerate(zip(self.shape, new_shape)) if a != b) - if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, keepdims=True)) + if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, dtype=self._np.dtype, keepdims=True)) elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(axis, keepdims=True)) else: raise NotImplementedError(op) diff --git a/test/test_dtype.py b/test/test_dtype.py new file mode 100644 index 0000000..a79d31a --- /dev/null +++ b/test/test_dtype.py @@ -0,0 +1,194 @@ +import unittest +import numpy as np +from teenygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX +from teenygrad.ops import Device +from teenygrad.tensor import Tensor, dtypes +from typing import Any, List + +def is_dtype_supported(dtype: DType): + # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) 
+  # for LLVM, it segfaults because it can't link to the casting function
+  if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1
+  if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
+  if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and not (OSX and Device.DEFAULT == "GPU")
+  if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"]
+  if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
+  if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"]
+  if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"]
+  if dtype == dtypes.bool:
+    # host-shareability is a requirement for storage buffers, but 'bool' type is not host-shareable
+    if Device.DEFAULT == "WEBGPU": return False
+  return True
+
+def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # don't cast internal dtypes
+
+def _test_to_np(a:Tensor, np_dtype, target):
+  if DEBUG >= 2: print(a)
+  na = a.numpy()
+  if DEBUG >= 2: print(na, na.dtype, a.lazydata.realized)
+  try:
+    assert na.dtype == np_dtype
+    np.testing.assert_allclose(na, target)
+  except AssertionError as e:
+    raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e
+
+def _assert_eq(tensor:Tensor, target_dtype:DType, target):
+  if DEBUG >= 2: print(tensor.numpy())
+  try:
+    assert tensor.dtype == target_dtype
+    np.testing.assert_allclose(tensor.numpy(), target)
+  except AssertionError as e:
+    raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
+
+def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
+def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
+def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
+
+class TestDType(unittest.TestCase):
+  DTYPE: Any = None
+  DATA: Any = None
+  @classmethod
+  def setUpClass(cls):
+    if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
+    cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
+  def setUp(self):
+    if self.DTYPE is None: raise unittest.SkipTest("base class")
+
+  def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))
+
+  def test_casts_to(self): list(map(
+    lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
+    get_available_cast_dtypes(self.DTYPE)
+  ))
+  def test_casts_from(self): list(map(
+    lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
+    get_available_cast_dtypes(self.DTYPE)
+  ))
+
+  def test_same_size_ops(self):
+    def get_target_dtype(dtype):
+      if any([dtypes.is_float(dtype), dtypes.is_float(self.DTYPE)]): return max([dtype, self.DTYPE], key=lambda x: x.priority)
+      return dtype if dtypes.is_unsigned(dtype) else self.DTYPE
+    list(map(
+      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype,
target_dtype=get_target_dtype(dtype)) if dtype.itemsize == self.DTYPE.itemsize else None, + get_available_cast_dtypes(self.DTYPE) + )) + def test_upcast_ops(self): list(map( + lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None, + get_available_cast_dtypes(self.DTYPE) + )) + def test_upcast_to_ops(self): + list(map( + lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None, + get_available_cast_dtypes(self.DTYPE) + )) + +def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): + if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return + if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return + target_dtype = target_dtype or (max([a_dtype, b_dtype], key=lambda x: x.priority) if a_dtype.priority != b_dtype.priority else max([a_dtype, b_dtype], key=lambda x: x.itemsize)) + _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8]) + _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16]) + _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]]) + _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy()) + +class TestBFloat16DType(unittest.TestCase): + def setUp(self): + if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported") + def test_bf16_to_float(self): + with self.assertRaises(AssertionError): + _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000]) + + def test_float_to_bf16(self): + with self.assertRaises(AssertionError): + _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000]) + + # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16) + + def test_bf16(self): + t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.bfloat16) + t.realize() + back = t.cast(dtypes.float32) + assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) + + def test_bf16_disk_write_read(self): + from extra.utils import temp + t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32) + t.to(f"disk:{temp('f32')}").realize() + + # hack to "cast" f32 -> bf16 + dat = open(temp('f32'), "rb").read() + adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)]) + with open(temp('bf16'), "wb") as f: f.write(adat) + + t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize() + back = t.cast(dtypes.float32) + assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) + +class TestHalfDtype(TestDType): DTYPE = dtypes.half + +class TestFloatDType(TestDType): DTYPE = dtypes.float + +class TestDoubleDtype(TestDType): DTYPE = dtypes.double + +class TestInt8Dtype(TestDType): + DTYPE = dtypes.int8 + @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") + def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) + +class TestUint8Dtype(TestDType): + DTYPE = dtypes.uint8 + @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") + def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) + +@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only 
bitcast in CPU and TORCH") +class TestBitCast(unittest.TestCase): + def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432]) + @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch") + def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432]) + def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0]) + + # NOTE: these are the same as normal casts + def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252]) + def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4]) + @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") + def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612]) + @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") + def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4]) + + def test_shape_change_bitcast(self): + with self.assertRaises(AssertionError): + _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000]) + +class TestInt16Dtype(TestDType): DTYPE = dtypes.int16 +class TestUint16Dtype(TestDType): DTYPE = dtypes.uint16 + +class TestInt32Dtype(TestDType): DTYPE = dtypes.int32 +class TestUint32Dtype(TestDType): DTYPE = dtypes.uint32 + +class TestInt64Dtype(TestDType): DTYPE = dtypes.int64 +class TestUint64Dtype(TestDType): DTYPE = dtypes.uint64 + +class TestBoolDtype(TestDType): DTYPE = dtypes.bool + +class TestEqStrDType(unittest.TestCase): + def test_image_ne(self): + if ImageDType is None: raise unittest.SkipTest("no ImageDType support") + assert dtypes.float == dtypes.float32, "float doesn't match?" + assert dtypes.imagef((1,2,4)) != dtypes.imageh((1,2,4)), "different image dtype doesn't match" + assert dtypes.imageh((1,2,4)) != dtypes.imageh((1,4,2)), "different shape doesn't match" + assert dtypes.imageh((1,2,4)) == dtypes.imageh((1,2,4)), "same shape matches" + assert isinstance(dtypes.imageh((1,2,4)), ImageDType) + def test_ptr_ne(self): + if PtrDType is None: raise unittest.SkipTest("no PtrDType support") + # TODO: is this the wrong behavior? + assert PtrDType(dtypes.float32) == dtypes.float32 + #assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) + #assert PtrDType(dtypes.float32) != dtypes.float32 + def test_strs(self): + if PtrDType is None: raise unittest.SkipTest("no PtrDType support") + self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") + self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float") + +if __name__ == '__main__': + unittest.main()
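
Note (not part of the patch above): a minimal sketch of the dtype-promotion rule that _test_ops in the new test/test_dtype.py relies on -- the operand with the higher priority wins, and a priority tie falls back to the larger itemsize. It only uses the priority/itemsize fields that DType already carries in teenygrad/helpers.py; the promote() helper is illustrative, not an API of the library.

# illustrative only: mirrors the default target_dtype computation in test/test_dtype.py::_test_ops
from teenygrad.helpers import dtypes

def promote(a, b):
  # higher priority wins; on a priority tie, the wider dtype (itemsize) wins
  return max([a, b], key=lambda x: x.priority) if a.priority != b.priority else max([a, b], key=lambda x: x.itemsize)

assert promote(dtypes.int32, dtypes.float32) == dtypes.float32   # float32 (priority 10) beats int32 (priority 5)
assert promote(dtypes.uint8, dtypes.int8) == dtypes.uint8        # unsigned ints sit one priority above their signed twins
assert promote(dtypes.float16, dtypes.float64) == dtypes.float64 # priority 9 vs 11
print("promotion rule matches the dtypes table in teenygrad/helpers.py")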