Commit

ops test passes

geohot committed Nov 24, 2023
1 parent ac4e379 commit 553ac19

Showing 5 changed files with 1,296 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml

@@ -20,5 +20,7 @@ jobs:
         python-version: 3.11
     - name: Install Dependencies
       run: pip install numpy tqdm
+    - name: Test Ops
+      run: PYTHONPATH="." python test/test_ops.py
     - name: Train MNIST
       run: PYTHONPATH="." python mnist.py
4 changes: 3 additions & 1 deletion import_from_tinygrad.py

@@ -1,13 +1,15 @@
 #!/usr/bin/env python3
 import pathlib

-FILES = ["tensor.py", "mlops.py", "nn/optim.py"]
+FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py"]
 src = pathlib.Path("../tinygrad/tinygrad")
 dest = pathlib.Path("teenygrad")

 for f in FILES:
   print("importing", f)
   rd = open(src/f).read()
   rd = rd.replace("from tinygrad.", "from teenygrad.")
+  rd = rd.replace("import tinygrad.", "import teenygrad.")
+  (dest/f).parent.mkdir(parents=True, exist_ok=True)
   with open(dest/f, "w") as f:
     f.write(rd)
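Since src is ../tinygrad/tinygrad, the script assumes a tinygrad checkout sitting beside this repository and is run from the teenygrad root (python3 import_from_tinygrad.py); the replace() calls rewrite the package prefix so the copied files import from the local teenygrad package instead.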
14 changes: 9 additions & 5 deletions teenygrad/helpers.py

@@ -2,6 +2,7 @@
 import os, functools
 import numpy as np
 from math import prod  # noqa: F401 # pylint:disable=unused-import
+from dataclasses import dataclass

 def dedup(x): return list(dict.fromkeys(x))  # retains list order
 def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
@@ -17,7 +18,8 @@ def getenv(key, default=0): return type(default)(os.getenv(key, default))
 DEBUG = getenv("DEBUG")
 CI = os.getenv("CI", "") != ""

-class DType(NamedTuple):
+@dataclass(frozen=True, order=True)
+class DType:
   priority: int  # this determines when things get upcasted
   itemsize: int
   name: str
@@ -29,10 +31,12 @@ class dtypes:
   @staticmethod
   def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
   @staticmethod
-  def is_float(x: DType) -> bool: return x == dtypes.float32
-  float32: Final[DType] = DType(4, 4, "float", np.float32)
-  int32: Final[DType] = DType(2, 1, "int32", np.int32)
+  def is_float(x: DType) -> bool: return x in (dtypes.float32, dtypes.float64)
+  float32: Final[DType] = DType(10, 4, "float", np.float32)
+  float64: Final[DType] = DType(11, 8, "double", np.float64)
+  int32: Final[DType] = DType(5, 4, "int", np.int32)
+  int64: Final[DType] = DType(7, 8, "long", np.int64)
   bool: Final[DType] = DType(0, 1, "bool", np.bool_)
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}

-ImageDType, IMAGE = None, None  # junk to remove
+ImageDType, IMAGE = None, 0  # junk to remove
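Swapping the NamedTuple for @dataclass(frozen=True, order=True) keeps DType hashable while also making instances totally ordered by their fields, priority first. With the renumbered priorities (bool 0, int 5, long 7, float 10, double 11), a plain max() now picks the correct upcast target. A minimal standalone sketch of the idea, in plain Python and numpy, mirroring the definitions above:

from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True, order=True)
class DType:
  priority: int  # compared first, so ordering follows upcast priority
  itemsize: int
  name: str
  np: type       # the backing numpy scalar type

float32 = DType(10, 4, "float", np.float32)
int32 = DType(5, 4, "int", np.int32)

# order=True derives __lt__/__gt__ etc. from the fields in declaration
# order, so max() resolves the upcast with no custom comparison code:
assert max(int32, float32) is float32
# frozen=True generates __hash__, so dtypes still work as dict keys
# (a NamedTuple gave both behaviors for free; a dataclass needs the flags):
assert hash(float32) == hash(DType(10, 4, "float", np.float32))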
19 changes: 11 additions & 8 deletions teenygrad/lazy.py

@@ -1,6 +1,5 @@
 from __future__ import annotations
-from typing import Tuple
-from teenygrad.helpers import dtypes
+from teenygrad.helpers import DType, dtypes
 from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
 import numpy as np

@@ -10,10 +9,11 @@ def toCPU(self): return self.x

 class LazyBuffer:
   device = "CPU"
-  dtype = dtypes.float32

-  def __init__(self, buf): self._np = buf
+  def __init__(self, buf: np.ndarray): self._np = buf

+  @property
+  def dtype(self): return dtypes.from_np(self._np.dtype)
   @property
   def realized(self): return RawCPUBuffer(self._np)
   @property
@@ -27,14 +27,17 @@ def fromCPU(x): return LazyBuffer(x)

   @staticmethod
   def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
-    if op == LoadOps.RAND: return LazyBuffer(np.random.default_rng(arg).random(size=shape, dtype=np.float32))
-    elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg))
+    if op == LoadOps.RAND: return LazyBuffer(np.random.default_rng(arg).random(size=shape, dtype=dtype.np))
+    elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg, dtype=dtype.np))
+    elif op == LoadOps.EMPTY: return LazyBuffer(np.empty(shape, dtype=dtype.np))
     else: raise NotImplementedError(op)

   def contiguous(x): return x
   def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))

-  def e(self, op, *srcs):
+  def cast(self, dtype:DType, bitcast:bool=False): return LazyBuffer(self._np.astype(dtype.np))
+
+  def e(self, op, *srcs:LazyBuffer):
     if op == UnaryOps.NEG: return LazyBuffer(-self._np)
     elif op == UnaryOps.EXP2: return LazyBuffer(np.exp2(self._np))
     elif op == UnaryOps.LOG2: return LazyBuffer(np.log2(self._np))
@@ -43,7 +46,7 @@ def e(self, op, *srcs):
     elif op == BinaryOps.ADD: return LazyBuffer(self._np + srcs[0]._np)
     elif op == BinaryOps.SUB: return LazyBuffer(self._np - srcs[0]._np)
     elif op == BinaryOps.MUL: return LazyBuffer(self._np * srcs[0]._np)
-    elif op == BinaryOps.DIV: return LazyBuffer(self._np / srcs[0]._np)
+    elif op == BinaryOps.DIV: return LazyBuffer((self._np / srcs[0]._np).astype(max(self.dtype, srcs[0].dtype).np))
     elif op == BinaryOps.MAX: return LazyBuffer(np.maximum(self._np, srcs[0]._np))
     elif op == BinaryOps.CMPLT: return LazyBuffer(self._np < srcs[0]._np)
     elif op == TernaryOps.WHERE: return LazyBuffer(np.where(self._np, srcs[0]._np, srcs[1]._np))
[Remainder of the teenygrad/lazy.py diff truncated; the fifth changed file, test/test_ops.py (the other 1,271 added lines), is not rendered.]
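The DIV change compensates for numpy's true division, which promotes integer inputs to float64; casting the result to max(self.dtype, srcs[0].dtype).np, well-defined now that DType is ordered, keeps the output on the promoted dtype of the two operands. A small illustration in plain numpy (the array values are made up):

import numpy as np

a = np.array([6, 7], dtype=np.int32)
b = np.array([2, 2], dtype=np.int32)

# numpy true division always yields float64 for integer inputs:
print((a / b).dtype)             # float64
# casting back to the operands' promoted dtype (int32 here) restores
# "output dtype follows input dtypes"; float inputs are unaffected,
# since float32/float32 already divides to float32:
print((a / b).astype(np.int32))  # [3 3]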
