Add batch normalization
parsiad committed Jan 31, 2024
1 parent 33e57b0 commit a015b00
Showing 8 changed files with 232 additions and 10 deletions.
12 changes: 6 additions & 6 deletions examples/mnist.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/micrograd_pp/__init__.py
@@ -1,10 +1,11 @@
from ._expr import Constant, Expr, Parameter, is_grad_enabled, maximum, no_grad, relu
from ._nn import Linear, ReLU, Sequential
from ._nn import BatchNorm1d, Linear, ReLU, Sequential, eval, is_eval
from ._opt import SGD

from . import datasets

__all__ = (
"BatchNorm1d",
"Constant",
"Expr",
"Linear",
@@ -13,6 +14,8 @@
"Sequential",
"SGD",
"datasets",
"eval",
"is_eval",
"is_grad_enabled",
"maximum",
"no_grad",
33 changes: 33 additions & 0 deletions src/micrograd_pp/_expr.py
@@ -10,6 +10,8 @@
import numpy as np
import numpy.typing as npt

from ._util import n_samples


_grad_mode = True

@@ -222,6 +224,21 @@ def max(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> Expr:
retval = _Squeeze(retval, dim=dim)
return retval

def mean(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> Expr:
"""Mean across one or more dimensions.
Parameters
----------
dim
Axis or axes along which to operate. By default, all axes are used.
keepdim
Whether the output retains the specified dimension(s)
"""
retval = _Sum(self, dim=dim) / n_samples(dim=dim, shape=self.shape)
if not keepdim:
retval = _Squeeze(retval, dim=dim)
return retval

def set_label(self, label: str) -> None:
"""Set the expression label."""
self._label = label
@@ -288,6 +305,22 @@ def unsqueeze(self, dim: int) -> Expr:
"""
return _Unsqueeze(self, dim=dim)

def var(self, dim: int | tuple[int, ...] | None = None, keepdim: bool = False) -> Expr:
"""Variance across one or more dimensions.
Parameters
----------
dim
Axis or axes along which to operate. By default, all axes are used.
keepdim
Whether the output retains the specified dimension(s)
"""
delta = self - self.mean(dim=dim, keepdim=True)
retval = (delta * delta).mean(dim=dim, keepdim=True)
if not keepdim:
retval = _Squeeze(retval, dim=dim)
return retval

@property
def dtype(self) -> npt.DTypeLike:
"""Data type."""
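To illustrate the new Expr.mean and Expr.var methods added above, here is a minimal sketch in the style of the unit tests in this commit (the array shape and reduction axes are illustrative, not part of the diff):

import numpy as np
import micrograd_pp as mpp

a = np.random.randn(4, 3, 2)
x = mpp.Parameter(a)

m = x.mean(dim=(0, 2))                 # reduce over axes 0 and 2 -> shape (3,)
v = x.var(dim=(0, 2), keepdim=True)    # keepdim retains the reduced axes -> shape (1, 3, 1)
np.testing.assert_allclose(v.value, a.var(axis=(0, 2), keepdims=True))

# Backpropagating through the mean spreads a gradient of 1/8 to each of the
# 4 * 2 = 8 reduced elements that feed every output entry.
m.backward()
np.testing.assert_allclose(x.grad, np.full((4, 3, 2), 1.0 / 8.0))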
106 changes: 104 additions & 2 deletions src/micrograd_pp/_nn.py
@@ -1,14 +1,116 @@
import contextlib
from collections.abc import Callable
from typing import Any
from typing import Any, Generator

import numpy as np

from ._expr import Expr, Parameter, relu
from ._expr import Constant, Expr, Parameter, relu
from ._util import n_samples


Module = Callable[[Expr], Expr]


_eval_mode = False


@contextlib.contextmanager
def eval() -> Generator[None, None, None]:
    """Context manager to switch to eval mode."""
    global _eval_mode
    state = _eval_mode
    _eval_mode = True
    try:
        yield
    finally:
        _eval_mode = state


def is_eval() -> bool:
"""Determines whether or not eval mode is enabled."""
return _eval_mode


class BatchNorm1d:
"""Batch normalization.
Parameters
----------
num_features
Number of features
affine
Whether to use learnable scale and shift parameters
dtype
Data type for running mean and variance and scale and shift parameters
eps
When standardizing, this quantity is added to the denominator for numerical stability
momentum
Momentum used for the running mean and variance computations (if None, an ordinary average is computed)
track_running_stats
Whether to keep a running mean and variance
"""

def __init__(
self,
num_features: int,
affine: bool = True,
dtype: type = np.float32,
eps: float = 1e-5,
momentum: float | None = 0.1,
track_running_stats: bool = True,
) -> None:
self._eps = eps
self._momentum = momentum
self._num_features = num_features
if track_running_stats:
self._running_mean = np.zeros((num_features,), dtype=dtype)
self._running_var = np.ones((num_features,), dtype=dtype)
else:
self._running_mean = None
self._running_var = None
if affine:
self._scale = Parameter(np.ones((num_features,), dtype=dtype))
self._shift = Parameter(np.zeros((num_features,), dtype=dtype))
else:
self._scale = None
self._shift = None
self._n = 0

def __call__(self, x: Expr) -> Expr:
dim = (0,) + tuple(range(2, x.ndim))
if self._running_mean is not None and self._running_var is not None and is_eval():
mean = Constant(self._running_mean)
var = Constant(self._running_var)
else:
mean = x.mean(dim=dim)
var = x.var(dim=dim)
if self._running_mean is not None and self._running_var is not None:
increment = n_samples(dim, x.shape)
n_new = self._n + increment
if self._momentum is None:
a = self._n / n_new
b = increment / n_new
else:
a = 1.0 - self._momentum
b = self._momentum
self._running_mean = a * self._running_mean + b * mean.value
self._running_var = a * self._running_var + b * var.value
self._n = n_new
shape = (1, x.shape[1]) + ((1,) * (x.ndim - 2))
mean = mean.expand(shape)
var = var.expand(shape)
x_norm = (x - mean) / ((var + self._eps) ** 0.5)
if self._scale is not None and self._shift is not None:
return self._scale * x_norm + self._shift
else:
return x_norm

def __repr__(self) -> str:
return (
f"BatchNorm1d({self._num_features}, x={self._eps=}, momentum={self._momentum}, "
f"affine={self._scale is not None and self._shift is not None}, "
f"track_running_stats={self._running_mean is not None and self._running_var is not None})"
)


class Linear:
"""Linear layer.
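A minimal usage sketch of the new BatchNorm1d module and the eval() context manager, modeled on the MNIST and unit tests in this commit (the shapes and batch size are illustrative only):

import numpy as np
import micrograd_pp as mpp

bn = mpp.BatchNorm1d(128)  # affine parameters and running statistics enabled by default

# Training mode: statistics come from the batch and update the running buffers.
# With the default momentum of 0.1, each call performs
#   running_mean <- 0.9 * running_mean + 0.1 * batch_mean (and likewise for the variance).
x = mpp.Constant(np.random.randn(64, 128).astype(np.float32))
y = bn(x)

# Eval mode: the stored running mean and variance are used instead of batch statistics.
with mpp.eval(), mpp.no_grad():
    y_eval = bn(x)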
9 changes: 9 additions & 0 deletions src/micrograd_pp/_util.py
@@ -0,0 +1,9 @@
import numpy as np


def n_samples(dim: int | tuple[int, ...] | None, shape: tuple[int, ...]) -> int:
if isinstance(dim, int):
return shape[dim]
if dim is None:
return np.prod(shape).item()
return np.prod([shape[d] for d in dim]).item()
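For reference, a small sketch of what the n_samples helper returns for the reduction arguments used above; these values follow directly from its definition (note the helper lives in the private _util module and is not re-exported):

from micrograd_pp._util import n_samples

shape = (4, 3, 2)
assert n_samples(0, shape) == 4        # single axis: size of that axis
assert n_samples((0, 2), shape) == 8   # tuple of axes: product of their sizes (4 * 2)
assert n_samples(None, shape) == 24    # None: total number of elements (4 * 3 * 2)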
20 changes: 20 additions & 0 deletions tests/test_expr.py
@@ -113,6 +113,17 @@ def test_maximum() -> None:
np.testing.assert_equal(b_.grad, ~grad)


@pytest.mark.parametrize("dim", DIMS)
def test_mean(dim: int | tuple[int, ...] | None) -> None:
    a = np.random.randn(4, 3, 2)
    a_ = mpp.Parameter(a)
    b_ = a_.mean(dim=dim)
    b_.backward()
    grad = np.ones_like(a) / np.ones_like(a).sum(axis=dim, keepdims=True)
    np.testing.assert_equal(a_.grad, grad)


def test_mult() -> None:
a = np.random.randn(4, 1, 2)
b = np.random.randn(3, 2)
@@ -205,3 +216,12 @@ def test_unsqueeze(dim: int) -> None:
b_.backward()
grad = np.ones_like(a)
np.testing.assert_equal(a_.grad, grad)


@pytest.mark.parametrize("dim", DIMS)
def test_var(dim: int | tuple[int, ...] | None) -> None:
a = np.random.randn(4, 3, 2)
b = a.var(axis=dim)
a_ = mpp.Constant(a)
b_ = a_.var(dim=dim)
np.testing.assert_allclose(b_.value, b)
3 changes: 2 additions & 1 deletion tests/test_mnist.py
@@ -46,6 +46,7 @@ def test_mnist(batch_sz: int = 64, n_epochs: int = 3):
# Feedforward neural network
model = mpp.Sequential(
mpp.Linear(28 * 28, 128),
mpp.BatchNorm1d(128),
mpp.ReLU(),
mpp.Linear(128, 10),
)
@@ -60,7 +61,7 @@
loss.backward(opt=opt)
opt.step()
test_x = mpp.Constant(test_images)
with mpp.no_grad():
with mpp.eval(), mpp.no_grad():
test_fx = model(test_x)
pred_labels = np.argmax(test_fx.value, axis=1)
accuracy = (pred_labels == test_labels).mean().item()
54 changes: 54 additions & 0 deletions tests/test_nn.py
@@ -0,0 +1,54 @@
from typing import Generator

import numpy as np
import pytest

import micrograd_pp as mpp

BATCH_SZ = 64
NUM_FEATURES = 10


@pytest.fixture(autouse=True)
def run_before_and_after_tests() -> Generator[None, None, None]:
np.random.seed(0)
yield


@pytest.mark.parametrize("momentum", [0.1, None])
def test_batch_norm_1d_track_running_stats(momentum: float) -> None:
num_iters = 1_000
    shift = np.random.randn(NUM_FEATURES)
    scale = np.random.randn(NUM_FEATURES)
bn = mpp.BatchNorm1d(NUM_FEATURES, affine=False, momentum=momentum)
for _ in range(num_iters):
x = scale * np.random.randn(BATCH_SZ, NUM_FEATURES) + shift
x_ = mpp.Constant(x)
bn(x_)
assert bn._running_mean is not None
assert bn._running_var is not None
np.testing.assert_allclose(bn._running_mean, shift, atol=0.1, rtol=0.0)
np.testing.assert_allclose(bn._running_var, scale * scale, atol=0.1, rtol=0.0)


def test_batch_norm_1d_standardize() -> None:
    shift = np.random.randn(NUM_FEATURES)
    scale = np.random.randn(NUM_FEATURES)
bn = mpp.BatchNorm1d(NUM_FEATURES, affine=False)
x = scale * np.random.randn(BATCH_SZ, NUM_FEATURES) + shift
x_ = mpp.Constant(x)
y_ = bn(x_)
np.testing.assert_allclose(y_.value.mean(axis=0), 0.0, atol=1e-6, rtol=0.0)
np.testing.assert_allclose(y_.value.var(axis=0), 1.0, atol=1e-3, rtol=0.0)


def test_batch_norm_1d_eval() -> None:
    shift = np.random.randn(NUM_FEATURES)
    scale = np.random.randn(NUM_FEATURES)
bn = mpp.BatchNorm1d(NUM_FEATURES, affine=False)
x = scale * np.random.randn(BATCH_SZ, NUM_FEATURES) + shift
x_ = mpp.Constant(x)
with mpp.eval():
y_ = bn(x_)
    # The output should be close to the input since the running mean and variance are initialized to 0 and 1
np.testing.assert_allclose(x_.value, y_.value, atol=1e-4, rtol=0.0)
