From 2da66f23530640a325df352e7b8bb0f14dc796df Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sat, 23 Sep 2023 14:44:39 +0800
Subject: [PATCH] teenygrad up to date

---
 teenygrad/helpers.py        |  4 ++-
 teenygrad/lazy.py           | 22 +++++++-------
 teenygrad/mlops.py          | 17 ++++++-----
 teenygrad/shape/symbolic.py |  1 +
 teenygrad/tensor.py         | 57 +++++++++++++++++++++++--------------
 5 files changed, 61 insertions(+), 40 deletions(-)
 create mode 100644 teenygrad/shape/symbolic.py
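Patch notes (this region between the diffstat and the first "diff --git" line is
conventionally ignored when the patch is applied): the change threads symbolic-shape
support through teenygrad by adding a `sint = int` alias in teenygrad/shape/symbolic.py,
an `all_int` TypeGuard in helpers.py, and `assert all_int(...)` guards in Tensor methods
that need concrete shapes; it also renames LazyBuffer.reduce_op to r and moves the
MovementOps methods to the end of the class. Below is a minimal sketch of how the new
TypeGuard is meant to narrow shapes for mypy; it is illustrative only and not part of
the diff, and `concrete_numel` is a hypothetical helper:

  from math import prod
  from typing import Any, Tuple
  from typing_extensions import TypeGuard

  def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]:
    return all(isinstance(s, int) for s in t)

  def concrete_numel(shape: Tuple[Any, ...]) -> int:
    # after this assert, mypy narrows `shape` to Tuple[int, ...]
    assert all_int(shape), f"cannot take numel of symbolic shape {shape}"
    return prod(shape)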
diff --git a/teenygrad/helpers.py b/teenygrad/helpers.py
index 454c76b..ff40020 100644
--- a/teenygrad/helpers.py
+++ b/teenygrad/helpers.py
@@ -1,4 +1,5 @@
-from typing import Union, Tuple, Iterator, NamedTuple, Optional, Final
+from typing import Union, Tuple, Iterator, NamedTuple, Optional, Final, Any
+from typing_extensions import TypeGuard
 import os, functools
 import numpy as np
 from math import prod # noqa: F401 # pylint:disable=unused-import
@@ -8,6 +9,7 @@ def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
 def flatten(l:Iterator): return [item for sublist in l for item in sublist]
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
+def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
 
 @functools.lru_cache(maxsize=None)
 def getenv(key, default=0): return type(default)(os.getenv(key, default))
diff --git a/teenygrad/lazy.py b/teenygrad/lazy.py
index c2ad470..4e971fe 100644
--- a/teenygrad/lazy.py
+++ b/teenygrad/lazy.py
@@ -18,11 +18,8 @@ def __init__(self, buf): self._np = buf
 
   @property
   def shape(self): return self._np.shape
-  def contiguous(x): return x
   def realize(x): return x
-  def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))
-
   @staticmethod
   def fromCPU(x): return LazyBuffer(x)
   def toCPU(self): return self._np
 
@@ -33,13 +30,8 @@ def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
     elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg))
     else: raise NotImplementedError(op)
 
-  # MovementOps
-  def reshape(self, arg): return LazyBuffer(self._np.reshape(arg))
-  def expand(self, arg): return LazyBuffer(np.broadcast_to(self._np, arg))
-  def shrink(self, arg): return LazyBuffer(self._np[tuple(slice(p[0], p[1], None) for p in arg)])
-  def permute(self, arg): return LazyBuffer(self._np.transpose(arg))
-  def pad(self, arg): return LazyBuffer(np.pad(self._np, arg))
-  def stride(self, arg): return LazyBuffer(self._np[tuple(slice(None, None, i) for i in arg)])
+  def contiguous(x): return x
+  def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))
 
   def e(self, op, *srcs):
     if op == UnaryOps.NEG: return LazyBuffer(-self._np)
@@ -56,7 +48,15 @@ def e(self, op, *srcs):
     elif op == TernaryOps.WHERE: return LazyBuffer(np.where(self._np, srcs[0]._np, srcs[1]._np))
     else: raise NotImplementedError(op)
 
-  def reduce_op(self, op, new_shape):
+  def r(self, op, new_shape):
     if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(shape_to_axis(self.shape, new_shape), keepdims=True))
     elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(shape_to_axis(self.shape, new_shape), keepdims=True))
     else: raise NotImplementedError(op)
+
+  # MovementOps
+  def reshape(self, arg): return LazyBuffer(self._np.reshape(arg))
+  def expand(self, arg): return LazyBuffer(np.broadcast_to(self._np, arg))
+  def shrink(self, arg): return LazyBuffer(self._np[tuple(slice(p[0], p[1], None) for p in arg)])
+  def permute(self, arg): return LazyBuffer(self._np.transpose(arg))
+  def pad(self, arg): return LazyBuffer(np.pad(self._np, arg))
+  def stride(self, arg): return LazyBuffer(self._np[tuple(slice(None, None, i) for i in arg)])
diff --git a/teenygrad/mlops.py b/teenygrad/mlops.py
index 3ac549f..db8094d 100644
--- a/teenygrad/mlops.py
+++ b/teenygrad/mlops.py
@@ -1,9 +1,10 @@
 import math
-from typing import Tuple, Optional
+from typing import Tuple, Optional, cast
 from teenygrad.helpers import argsort, DType
 from teenygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from teenygrad.tensor import Function
 from teenygrad.lazy import LazyBuffer
+from teenygrad.shape.symbolic import sint
 
 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -88,20 +89,20 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
 class Sum(Function):
   def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.reduce_op(ReduceOps.SUM, new_shape)
+    return x.r(ReduceOps.SUM, new_shape)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.expand(self.input_shape)
 
 class Max(Function):
   def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
+    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     # 1s in locations where the max was chosen (can be two locations)
     max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-    div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
+    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
     return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
 
 # ************* binary ops *************
@@ -165,7 +166,7 @@ def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
     return x.expand(shape)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
+    return grad_output.r(ReduceOps.SUM, self.input_shape)
 
 class Reshape(Function):
   def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
@@ -192,12 +193,14 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.shrink(self.narg)
 
 class Shrink(Function):
-  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
     return x.shrink(arg)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.pad(self.narg)
+    assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
+    # need this cast because mypy cannot narrow the type even with assert
+    return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
 
 class Flip(Function):
   def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
diff --git a/teenygrad/shape/symbolic.py b/teenygrad/shape/symbolic.py
new file mode 100644
index 0000000..49a6374
--- /dev/null
+++ b/teenygrad/shape/symbolic.py
@@ -0,0 +1 @@
+sint = int
\ No newline at end of file
diff --git a/teenygrad/tensor.py b/teenygrad/tensor.py
index 2ba5cd9..dac88ef 100644
--- a/teenygrad/tensor.py
+++ b/teenygrad/tensor.py
@@ -7,9 +7,10 @@
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
-from teenygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
+from teenygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
 from teenygrad.lazy import LazyBuffer
 from teenygrad.ops import Device, LoadOps
+from teenygrad.shape.symbolic import sint
 
 # An instantiation of the Function is the Context
 class Function:
@@ -75,7 +76,7 @@ def __hash__(self): return id(self)
   def device(self) -> str: return self.lazydata.device
 
   @property
-  def shape(self) -> Tuple[int, ...]: return self.lazydata.shape
+  def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape
 
   @property
   def dtype(self) -> DType: return self.lazydata.dtype
@@ -90,13 +91,13 @@ def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK
     if self.device.startswith("DISK"):
       if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
-      self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
+      self.lazydata.contiguous().realize().realized._copyin(x.numpy()) # type: ignore
       return self
     if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
     assert not x.requires_grad # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
-    if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
+    if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
     self.lazydata = x.lazydata
     return self
 
@@ -121,7 +122,9 @@ def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=N
     return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
 
   @staticmethod
-  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
+  def empty(*shape, **kwargs):
+    assert all_int(shape), f"cannot create with symbolic shape {shape}"
+    return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
 
   _seed: int = int(time.time())
   @staticmethod
@@ -129,13 +132,14 @@ def manual_seed(seed=0): Tensor._seed = seed
 
   @staticmethod
   def rand(*shape, **kwargs):
+    assert all_int(shape), f"cannot create with symbolic shape {shape}"
     Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
 
   # ***** creation helper functions *****
 
   @staticmethod
-  def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
+  def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
 
   @staticmethod
   def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@@ -173,22 +177,22 @@ def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
     return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low
 
   @staticmethod
-  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
 
   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
-  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
 
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
+    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
 
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
+    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
 
   # ***** toposort and backward pass *****
@@ -224,11 +228,11 @@ def backward(self):
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
     assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
-    return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
+    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
   def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
     ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
     return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
@@ -299,6 +303,7 @@ def normalize_int(e, i, dim_sz):
       if isinstance(s, int):
         dim_collapsed += 1
       else:
+        assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
         final_shape.append(dim_shape)
         if isinstance(s, Tensor):
           tensors.append(s)
@@ -326,7 +331,7 @@ def normalize_int(e, i, dim_sz):
     return ret
 
   # NOTE: using slice is discouraged and things should migrate to pad and shrink
-  def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
+  def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
     arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
     padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
     return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
@@ -367,6 +372,7 @@ def repeat(self, repeats):
     return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
 
   def chunk(self, num:int, dim:int) -> List[Tensor]:
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
     slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
     return [self[tuple(sl)] for sl in slice_params]
@@ -409,11 +415,13 @@ def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, ke
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
   def mean(self, axis=None, keepdim=False):
+    assert all_int(self.shape), "does not support symbolic shape"
     out = self.sum(axis=axis, keepdim=keepdim)
-    return out * (math.prod(out.shape)/math.prod(self.shape))
+    return out * (prod(out.shape)/prod(self.shape))
   def std(self, axis=None, keepdim=False, correction=1):
+    assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
+    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
 
   def _softmax(self, axis):
     m = self - self.max(axis=axis, keepdim=True)
     e = m.exp()
@@ -429,8 +437,8 @@ def log_softmax(self, axis=-1):
 
   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
-      return math.prod(self.shape) - idx.max() - 1
+      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
+      return prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
@@ -441,6 +449,7 @@ def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, kee
 
   def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
     assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
     assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
     slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
@@ -552,8 +561,12 @@ def tan(self): return self.sin() / self.cos()
 
   @staticmethod
   def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
-  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
-  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
+  def triu(self, k:int=0) -> Tensor:
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+  def tril(self, k:int=0) -> Tensor:
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+    return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
 
   # ***** math functions (unary) *****
   def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
@@ -693,6 +706,8 @@ def dropout(self, p=0.5) -> Tensor:
     return self * mask * (1/(1.0 - p))
 
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
+    # NOTE: it works if key, value have symbolic shape
+    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
     return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
@@ -716,7 +731,7 @@ def half(self) -> Tensor: return self.cast(dtypes.float16)
 
   @property
   def ndim(self) -> int: return len(self.shape)
-  def numel(self) -> int: return math.prod(self.shape)
+  def numel(self) -> sint: return prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
   def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
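
After the diff: a hypothetical smoke test of the guarded paths introduced above. This is
an illustrative sketch, not part of the patch; it assumes a teenygrad checkout with this
change applied is on PYTHONPATH:

  from teenygrad.tensor import Tensor
  from teenygrad.shape.symbolic import sint  # alias introduced above: sint = int

  t = Tensor.rand(2, 3)      # all_int((2, 3)) holds, so rand/empty accept the shape
  print(t.mean().numpy())    # mean() now asserts all_int(self.shape) before using prod()
  n: sint = t.numel()        # numel() is typed as sint after the patch
  print(n)                   # 6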