Skip to content

Commit

Permalink
add mypy + pre commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Nov 24, 2023
1 parent 6a7a232 commit c7bba5f
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 8 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ jobs:
- name: Get code size
- name: Train MNIST
run: PYTHONPATH="." python mnist.py
- name: Install torch for testing
run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Install mypy + torch for testing
run: pip install mypy torch --extra-index-url https://download.pytorch.org/whl/cpu
- name: Test ops / dtype / optim
run: |
PYTHONPATH="." python test/test_ops.py
PYTHONPATH="." python test/test_dtype.py
PYTHONPATH="." python test/test_optim.py
- name: Check types with mypy
run: mypy
15 changes: 15 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
repos:
- repo: local
hooks:
- id: tests
name: tests
entry: env PYTHONPATH="." pytest test/
language: system
always_run: true
pass_filenames: false
- id: mypy
name: mypy
entry: mypy
language: system
always_run: true
pass_filenames: false
9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[mypy]
warn_unused_configs = True
files = teenygrad
ignore_missing_imports = True
check_untyped_defs = True
explicit_package_bases = True
warn_unreachable = True
warn_redundant_casts = True
warn_unused_ignores = True
2 changes: 1 addition & 1 deletion teenygrad/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]:
  # Broadcast a lone int to a cnt-length tuple; an existing tuple passes through untouched.
  if isinstance(x, int):
    return (x,) * cnt
  return x
def flatten(l:Iterator):
  # Concatenate one level of nesting: [[a, b], [c]] -> [a, b, c].
  out = []
  for sub in l:
    out.extend(sub)
  return out
def argsort(x):
  # Indices that would sort x, returned as the same sequence type as x.
  # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
  idxs = sorted(range(len(x)), key=lambda i: x[i])
  return type(x)(idxs)
def all_int(t: Tuple[Any, ...]) -> bool:
  # True iff every element of t is an int (e.g. a fully-concrete shape).
  # Fixed return annotation: all() yields a bool, not Tuple[int, ...].
  return all(isinstance(s, int) for s in t)
def all_int(t: Tuple[Any, ...]) -> bool:
  # A shape is concrete when no element is symbolic, i.e. everything is an int.
  for s in t:
    if not isinstance(s, int):
      return False
  return True
def round_up(num, amt:int):
  # Smallest multiple of amt that is >= num.
  q, r = divmod(num, amt)
  return num if r == 0 else (q + 1) * amt

@functools.lru_cache(maxsize=None)
Expand Down
5 changes: 4 additions & 1 deletion teenygrad/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ class LazyBuffer:

def __init__(self, buf: np.ndarray): self._np = buf  # wrap an already-materialized numpy array; no copy is made

@property
def base(self): return self  # no views here: every buffer is its own base
@property
def dtype(self): return dtypes.from_np(self._np.dtype)  # presumably maps the numpy dtype to the project's DType — verify in dtypes module
@property
Expand All @@ -21,7 +23,8 @@ def shape(self): return self._np.shape
def __repr__(self): return f"<LB {self.shape} {self.dtype}>"  # compact debug string: shape and dtype only

def schedule(self, seen=None): return []  # nothing is deferred in this backend, so the schedule is always empty
def is_unrealized_const(self): return False  # buffers here are always backed by a concrete numpy array
def is_unrealized_contiguous_const(self): return False  # always realized: backed by a concrete numpy array
def copy_to_device(self, device:str) -> LazyBuffer: return self  # single-device backend: nothing to move, the same buffer is returned

@staticmethod
def fromCPU(x): return LazyBuffer(x)  # alternate constructor: wrap a host (numpy) array in a LazyBuffer
Expand Down
4 changes: 3 additions & 1 deletion teenygrad/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum, auto
from typing import Optional

class UnaryOps(Enum):
  # Elementwise single-input ops. Member order matters: auto() assigns values 1..8.
  NOOP = auto()
  EXP2 = auto()
  LOG2 = auto()
  CAST = auto()
  SIN = auto()
  SQRT = auto()
  RECIP = auto()
  NEG = auto()
class BinaryOps(Enum):
  # Elementwise two-input ops. Member order matters: auto() assigns values 1..7.
  ADD = auto()
  SUB = auto()
  MUL = auto()
  DIV = auto()
  MAX = auto()
  MOD = auto()
  CMPLT = auto()
Expand All @@ -10,4 +11,5 @@ class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(
class Device:
  # Device registry stub: only a CPU backend exists here.
  DEFAULT = "CPU"
  _buffers = ["CPU"]
  # The diff artifact left canonicalize defined twice (a plain function shadowed
  # by the staticmethod); keep only the staticmethod form, which is what callers get.
  @staticmethod
  def canonicalize(device:Optional[str]) -> str:
    # Whatever device the caller asks for, the answer is CPU.
    return "CPU"
6 changes: 3 additions & 3 deletions teenygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, by
data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else:
data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}")

# data is a LazyBuffer, but it might be on the wrong device
if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
self.lazydata = data if data.device == device else data.copy_to_device(device)

def __repr__(self):
Expand Down Expand Up @@ -673,8 +673,8 @@ def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tens
return (x, y)

def _to_float(self, x:Union[Tensor, float]):
return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \
and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x
return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
x = self._to_float(x)
Expand Down

0 comments on commit c7bba5f

Please sign in to comment.