diff --git a/README.md b/README.md
index 8fe5d83..c34473f 100644
--- a/README.md
+++ b/README.md
@@ -74,7 +74,7 @@ model = mpp.Sequential(
 def cross_entropy_loss(input_: mpp.Expr, target: mpp.Expr) -> mpp.Expr:
     n, _ = input_.shape
     input_max = input_.max(dim=1)
-    delta = input_ - input_max.expand(input_.shape)
+    delta = input_ - input_max
     log_sum_exp = delta.exp().sum(dim=1).log().squeeze()
     return (log_sum_exp - delta[np.arange(n), target]).sum() / n
 
diff --git a/python/micrograd_pp/_expr.py b/python/micrograd_pp/_expr.py
index 345bafd..71c9ce4 100644
--- a/python/micrograd_pp/_expr.py
+++ b/python/micrograd_pp/_expr.py
@@ -56,7 +56,7 @@ def __add__(self, other: Any) -> Expr:
             other = float(other)
         if isinstance(other, float):
             return _AddScalar(self, other)
-        return _Add(self, other)
+        return _Add(*_maybe_expand(self, other))
 
     def __getitem__(self, index: Any) -> Expr:
         return _Slice(self, index=index)
@@ -69,7 +69,7 @@ def __mul__(self, other: Any) -> Expr:
             other = float(other)
         if isinstance(other, float):
             return _MultScalar(self, other)
-        return _Mult(self, other)
+        return _Mult(*_maybe_expand(self, other))
 
     def __neg__(self) -> Expr:
         return self * (-1.0)
@@ -331,7 +331,7 @@ def __init__(self, c: npt.NDArray, label: str | None = None) -> None:
 
 def maximum(a: Expr, b: Expr) -> Expr:
     """The element-wise maximum of two expressions."""
-    return _Maximum(a, b)
+    return _Maximum(*_maybe_expand(a, b))
 
 
 def relu(expr: Expr) -> Expr:
@@ -341,8 +341,6 @@ def relu(expr: Expr) -> Expr:
 
 class _Add(Expr):
     def __init__(self, a: Expr, b: Expr) -> None:
-        # TODO(parsiad): Support broadcasting
-        _raise_if_not_same_shape(a, b)
         super().__init__(value=a._value + b._value, children=(a, b))
         self._a = a
         self._b = b
@@ -372,6 +370,7 @@ def _backward(self, grad: npt.NDArray) -> None:
 
 class _Expand(Expr):
     def __init__(self, a: Expr, shape: tuple[int, ...]) -> None:
+        # TODO(parsiad): Materializing a broadcast is expensive
         super().__init__(value=np.broadcast_to(a._value, shape=shape), children=(a,))
         self._a = a
 
@@ -434,7 +433,6 @@ def func() -> npt.NDArray:
 
 class _Maximum(Expr):
     def __init__(self, a: Expr, b: Expr) -> None:
-        _raise_if_not_same_shape(a, b)
         super().__init__(value=np.maximum(a._value, b._value), children=(a, b))
         self._a = a
         self._b = b
@@ -446,8 +444,6 @@ def _backward(self, grad: npt.NDArray) -> None:
 
 class _Mult(Expr):
     def __init__(self, a: Expr, b: Expr) -> None:
-        # TODO(parsiad): Support broadcasting
-        _raise_if_not_same_shape(a, b)
         super().__init__(value=a._value * b._value, children=(a, b))
         self._a = a
         self._b = b
@@ -543,8 +539,8 @@ def _backward(self, grad: npt.NDArray) -> None:
         self._a.update_grad(lambda: grad.squeeze(axis=self._dim))
 
 
-def _raise_if_not_same_shape(*exprs: Expr):
-    shape = next(iter(exprs)).shape
-    if not all(expr.shape == shape for expr in exprs):
-        msg = "Operands must be the same shape"
-        raise ValueError(msg)
+def _maybe_expand(a: Expr, b: Expr) -> tuple[Expr, Expr]:
+    shape = np.broadcast_shapes(a.shape, b.shape)
+    a_ = a if a.shape == shape else _Expand(a, shape=shape)
+    b_ = b if b.shape == shape else _Expand(b, shape=shape)
+    return a_, b_
diff --git a/tests/test_expr.py b/tests/test_expr.py
index cadc936..539915c 100644
--- a/tests/test_expr.py
+++ b/tests/test_expr.py
@@ -16,25 +16,14 @@ def run_before_and_after_tests() -> Generator[None, None, None]:
 
 
 def test_add() -> None:
-    a = np.random.randn(3, 2)
+    a = np.random.randn(4, 1, 2)
     b = np.random.randn(3, 2)
     a_ = mpp.Parameter(a)
     b_ = mpp.Parameter(b)
     c_ = a_ + b_
     c_.backward()
-    grad = np.ones_like(a)
-    np.testing.assert_equal(a_.grad, grad)
-    np.testing.assert_equal(b_.grad, grad)
-
-
-def test_add_bcast_fails() -> None:
-    a = np.random.randn(3, 2)
-    b = np.random.randn(3, 1)
-    a_ = mpp.Parameter(a)
-    b_ = mpp.Parameter(b)
-    with pytest.raises(ValueError):
-        c_ = a_ + b_
-    del c_
+    np.testing.assert_equal(a_.grad, np.full_like(a, 3.0))
+    np.testing.assert_equal(b_.grad, np.full_like(b, 4.0))
 
 
 def test_add_scalar() -> None:
@@ -125,24 +114,16 @@ def test_maximum() -> None:
 
 
 def test_mult() -> None:
-    a = np.random.randn(3, 2)
+    a = np.random.randn(4, 1, 2)
     b = np.random.randn(3, 2)
     a_ = mpp.Parameter(a)
     b_ = mpp.Parameter(b)
     c_ = a_ * b_
     c_.backward()
-    np.testing.assert_equal(a_.grad, b)
-    np.testing.assert_equal(b_.grad, a)
-
-
-def test_mult_bcast_fails() -> None:
-    a = np.random.randn(3, 2)
-    b = np.random.randn(3, 1)
-    a_ = mpp.Parameter(a)
-    b_ = mpp.Parameter(b)
-    with pytest.raises(ValueError):
-        c_ = a_ * b_
-    del c_
+    a_grad = np.broadcast_to(b.sum(axis=0, keepdims=True), shape=(4, 1, 2))
+    b_grad = np.broadcast_to(a.sum(axis=0), shape=(3, 2))
+    np.testing.assert_equal(a_.grad, a_grad)
+    np.testing.assert_equal(b_.grad, b_grad)
 
 
 def test_mult_scalar() -> None:
diff --git a/tests/test_mnist.py b/tests/test_mnist.py
index d6d88d6..215486e 100644
--- a/tests/test_mnist.py
+++ b/tests/test_mnist.py
@@ -13,7 +13,7 @@ def run_before_and_after_tests():
 def cross_entropy_loss(input_: mpp.Expr, target: mpp.Expr) -> mpp.Expr:
     n, _ = input_.shape
     input_max = input_.max(dim=1)
-    delta = input_ - input_max.expand(input_.shape)
+    delta = input_ - input_max
     log_sum_exp = delta.exp().sum(dim=1).log().squeeze()
     return (log_sum_exp - delta[np.arange(n), target]).sum() / n
 
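
Note (not part of the patch): the `.expand(input_.shape)` call deleted from `cross_entropy_loss` becomes unnecessary because elementwise ops now broadcast like NumPy via `_maybe_expand`. A minimal sketch of the new behavior, assuming the `import micrograd_pp as mpp` alias used in the tests and that `max(dim=1)` keeps the reduced axis (which the deleted `.expand` call implies):

import numpy as np

import micrograd_pp as mpp  # alias as used in tests/test_expr.py

x = mpp.Parameter(np.random.randn(5, 10))
row_max = x.max(dim=1)  # assumed shape (5, 1): the reduced axis is kept
delta = x - row_max     # (5, 10) - (5, 1) now broadcasts; no .expand needed
assert delta.shape == (5, 10)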
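
The patch shows `_Expand`'s forward pass but not its `_backward`, which is what makes the new gradient expectations in `test_add` and `test_mult` come out right: the upstream gradient must be summed back down to the operand's pre-broadcast shape. A standalone NumPy sketch of that reduction (the helper `reduce_to_shape` is hypothetical, not micrograd_pp API):

import numpy as np

def reduce_to_shape(grad: np.ndarray, shape: tuple[int, ...]) -> np.ndarray:
    # Invert np.broadcast_to by summing over the axes broadcasting touched.
    # 1. Sum away leading axes that broadcasting prepended.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # 2. Sum (keeping dims) over axes that were stretched from size 1.
    for axis, n in enumerate(shape):
        if n == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

# In test_add, (4, 1, 2) + (3, 2) broadcasts to (4, 3, 2), and backward seeds
# a gradient of ones with that shape:
g = np.ones((4, 3, 2))
assert (reduce_to_shape(g, (4, 1, 2)) == 3.0).all()  # a: stretched axis of length 3
assert (reduce_to_shape(g, (3, 2)) == 4.0).all()     # b: prepended axis of length 4

This is why `a_.grad` is expected to be all 3s (each element of `a` was replicated three times along the stretched middle axis) and `b_.grad` all 4s (replicated four times along the prepended leading axis).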