Skip to content

Commit

Permalink
Reduce dimensions after max and sum by default
Browse files Browse the repository at this point in the history
  • Loading branch information
parsiad committed Jan 29, 2024
1 parent e32a3fd commit 76276a0
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ model = mpp.Sequential(

def cross_entropy_loss(input_: mpp.Expr, target: mpp.Expr) -> mpp.Expr:
n, _ = input_.shape
input_max = input_.max(dim=1)
input_max = input_.max(dim=1, keepdim=True)
delta = input_ - input_max
log_sum_exp = delta.exp().sum(dim=1).log().squeeze()
return (log_sum_exp - delta[np.arange(n), target]).sum() / n
Expand Down
18 changes: 14 additions & 4 deletions src/micrograd_pp/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,20 @@ def log(self) -> Expr:
"""Take the element-wise natural logarithm."""
return _Log(self)

def max(self, dim: int | tuple[int, ...] | None = None) -> Expr:
def max(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> Expr:
"""Compute the maximum across one or more dimensions.
Parameters
----------
dim
Axis or axes along which to operate. By default, all axes are used.
keepdim
Whether the output retains the specified dimension(s).
"""
return _Max(a=self, dim=dim)
retval = _Max(self, dim=dim)
if not keepdim:
retval = _Squeeze(retval, dim=dim)
return retval

def set_label(self, label: str) -> None:
"""Set the expression label."""
Expand All @@ -208,15 +213,20 @@ def squeeze(self, dim: int | tuple[int, ...] | None = None) -> Expr:
"""
return _Squeeze(self, dim=dim)

def sum(self, dim: int | tuple[int, ...] | None = None) -> Expr:
def sum(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> Expr:
"""Sum across one or more dimensions.
Parameters
----------
dim
Axis or axes along which to operate. By default, all axes are used.
keepdim
Whether the output retains the specified dimension(s).
"""
return _Sum(self, dim=dim)
retval = _Sum(self, dim=dim)
if not keepdim:
retval = _Squeeze(retval, dim=dim)
return retval

def transpose(self, dim0: int, dim1: int) -> Expr:
"""Transpose axes.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def run_before_and_after_tests():

def cross_entropy_loss(input_: mpp.Expr, target: mpp.Expr) -> mpp.Expr:
n, _ = input_.shape
input_max = input_.max(dim=1)
input_max = input_.max(dim=1, keepdim=True)
delta = input_ - input_max
log_sum_exp = delta.exp().sum(dim=1).log().squeeze()
return (log_sum_exp - delta[np.arange(n), target]).sum() / n
Expand Down

0 comments on commit 76276a0

Please sign in to comment.