Skip to content

Commit

Permalink
Add automatic broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
parsiad committed Jan 28, 2024
1 parent 0bc15c0 commit c2a32d1
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ model = mpp.Sequential(
def cross_entropy_loss(input_: mpp.Expr, target: mpp.Expr) -> mpp.Expr:
    """Numerically stable cross-entropy loss.

    Shifts logits by their per-row maximum before exponentiating
    (the log-sum-exp trick) so ``exp`` cannot overflow.

    Parameters
    ----------
    input_: Logits of shape ``(n, num_classes)``.
    target: Integer class labels for each of the ``n`` rows.

    Returns
    -------
    Mean cross-entropy over the batch.
    """
    n, _ = input_.shape
    input_max = input_.max(dim=1)
    # Subtraction broadcasts the row maxima across dim 1 automatically;
    # no explicit .expand(...) is needed.
    delta = input_ - input_max
    log_sum_exp = delta.exp().sum(dim=1).log().squeeze()
    return (log_sum_exp - delta[np.arange(n), target]).sum() / n

Expand Down
22 changes: 9 additions & 13 deletions python/micrograd_pp/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __add__(self, other: Any) -> Expr:
other = float(other)
if isinstance(other, float):
return _AddScalar(self, other)
return _Add(self, other)
return _Add(*_maybe_expand(self, other))

def __getitem__(self, index: Any) -> Expr:
    """Index or slice the expression, producing a ``_Slice`` node."""
    return _Slice(self, index=index)
Expand All @@ -69,7 +69,7 @@ def __mul__(self, other: Any) -> Expr:
other = float(other)
if isinstance(other, float):
return _MultScalar(self, other)
return _Mult(self, other)
return _Mult(*_maybe_expand(self, other))

def __neg__(self) -> Expr:
    """Negate the expression, implemented as multiplication by -1."""
    return self * (-1.0)
Expand Down Expand Up @@ -331,7 +331,7 @@ def __init__(self, c: npt.NDArray, label: str | None = None) -> None:

def maximum(a: Expr, b: Expr) -> Expr:
    """The element-wise maximum of two expressions.

    Operands are first broadcast to a common shape (NumPy broadcasting
    rules) via ``_maybe_expand``.
    """
    return _Maximum(*_maybe_expand(a, b))


def relu(expr: Expr) -> Expr:
Expand All @@ -341,8 +341,6 @@ def relu(expr: Expr) -> Expr:

class _Add(Expr):
def __init__(self, a: Expr, b: Expr) -> None:
    """Element-wise sum node.

    Operands must already share a shape; broadcasting is handled
    upstream by ``_maybe_expand`` (see ``Expr.__add__``), so no shape
    check is performed here.
    """
    super().__init__(value=a._value + b._value, children=(a, b))
    self._a = a
    self._b = b
Expand Down Expand Up @@ -372,6 +370,7 @@ def _backward(self, grad: npt.NDArray) -> None:

class _Expand(Expr):
def __init__(self, a: Expr, shape: tuple[int, ...]) -> None:
    """Node that broadcasts ``a`` to ``shape`` (NumPy broadcasting rules)."""
    # TODO(parsiad): Materializing a broadcast is expensive
    super().__init__(value=np.broadcast_to(a._value, shape=shape), children=(a,))
    self._a = a

Expand Down Expand Up @@ -434,7 +433,6 @@ def func() -> npt.NDArray:

class _Maximum(Expr):
def __init__(self, a: Expr, b: Expr) -> None:
    """Element-wise maximum node.

    Operands are pre-broadcast to a common shape by ``_maybe_expand``
    in the public ``maximum`` wrapper, so no shape check is needed.
    """
    super().__init__(value=np.maximum(a._value, b._value), children=(a, b))
    self._a = a
    self._b = b
Expand All @@ -446,8 +444,6 @@ def _backward(self, grad: npt.NDArray) -> None:

class _Mult(Expr):
def __init__(self, a: Expr, b: Expr) -> None:
    """Element-wise product node.

    Operands must already share a shape; broadcasting is handled
    upstream by ``_maybe_expand`` (see ``Expr.__mul__``), so no shape
    check is performed here.
    """
    super().__init__(value=a._value * b._value, children=(a, b))
    self._a = a
    self._b = b
Expand Down Expand Up @@ -543,8 +539,8 @@ def _backward(self, grad: npt.NDArray) -> None:
self._a.update_grad(lambda: grad.squeeze(axis=self._dim))


def _raise_if_not_same_shape(*exprs: Expr):
    """Raise ``ValueError`` unless every expression has the same shape.

    NOTE(review): this pre-broadcasting validation helper is superseded
    by ``_maybe_expand`` in this change set.
    """
    shape = next(iter(exprs)).shape
    if not all(expr.shape == shape for expr in exprs):
        msg = "Operands must be the same shape"
        raise ValueError(msg)
def _maybe_expand(a: Expr, b: Expr) -> tuple[Expr, Expr]:
    """Return ``a`` and ``b``, each expanded to their common broadcast shape.

    The target shape follows NumPy broadcasting rules. An operand that
    already has the target shape is returned unchanged (no extra node).
    """
    target = np.broadcast_shapes(a.shape, b.shape)

    def _expand(expr: Expr) -> Expr:
        # Only wrap in an _Expand node when the shape actually differs.
        return expr if expr.shape == target else _Expand(expr, shape=target)

    return _expand(a), _expand(b)
35 changes: 8 additions & 27 deletions tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,14 @@ def run_before_and_after_tests() -> Generator[None, None, None]:


def test_add() -> None:
    """Broadcasted add: gradients reduce back to each operand's shape."""
    a = np.random.randn(4, 1, 2)
    b = np.random.randn(3, 2)
    a_ = mpp.Parameter(a)
    b_ = mpp.Parameter(b)
    c_ = a_ + b_
    c_.backward()
    # c has shape (4, 3, 2). The seed gradient is all ones, so each
    # operand's gradient sums the ones over its broadcast axis:
    # size-3 axis for a (-> 3.0 each), size-4 axis for b (-> 4.0 each).
    np.testing.assert_equal(a_.grad, np.full_like(a, 3.0))
    np.testing.assert_equal(b_.grad, np.full_like(b, 4.0))


def test_add_scalar() -> None:
Expand Down Expand Up @@ -125,24 +114,16 @@ def test_maximum() -> None:


def test_mult() -> None:
    """Broadcasted multiply: gradients are the other operand, reduced over broadcast axes."""
    a = np.random.randn(4, 1, 2)
    b = np.random.randn(3, 2)
    a_ = mpp.Parameter(a)
    b_ = mpp.Parameter(b)
    c_ = a_ * b_
    c_.backward()
    # c has shape (4, 3, 2). d c / d a is b (expanded), summed over a's
    # size-3 broadcast axis; d c / d b is a (expanded), summed over b's
    # size-4 broadcast axis.
    a_grad = np.broadcast_to(b.sum(axis=0, keepdims=True), shape=(4, 1, 2))
    b_grad = np.broadcast_to(a.sum(axis=0), shape=(3, 2))
    np.testing.assert_equal(a_.grad, a_grad)
    np.testing.assert_equal(b_.grad, b_grad)


def test_mult_scalar() -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def run_before_and_after_tests():
def cross_entropy_loss(input_: mpp.Expr, target: mpp.Expr) -> mpp.Expr:
    """Numerically stable cross-entropy loss over a batch of logits.

    Shifts logits by their per-row maximum before exponentiating
    (the log-sum-exp trick) so ``exp`` cannot overflow.

    Parameters
    ----------
    input_: Logits of shape ``(n, num_classes)``.
    target: Integer class labels for each of the ``n`` rows.

    Returns
    -------
    Mean cross-entropy over the batch.
    """
    n, _ = input_.shape
    input_max = input_.max(dim=1)
    # Subtraction broadcasts the row maxima across dim 1 automatically;
    # no explicit .expand(...) is needed.
    delta = input_ - input_max
    log_sum_exp = delta.exp().sum(dim=1).log().squeeze()
    return (log_sum_exp - delta[np.arange(n), target]).sum() / n

Expand Down

0 comments on commit c2a32d1

Please sign in to comment.