Commit 33e57b0
Add decorator to disable gradient computation
parsiad committed Jan 30, 2024
1 parent b6ada09 commit 33e57b0
Showing 5 changed files with 46 additions and 8 deletions.
examples/mnist.ipynb (3 changes: 2 additions & 1 deletion)
@@ -188,7 +188,8 @@
 "    loss.backward(opt=opt)\n",
 "    opt.step()\n",
 "    test_x = mpp.Constant(test_images_)\n",
-"    test_fx = model(test_x)\n",
+"    with mpp.no_grad():\n",
+"        test_fx = model(test_x)\n",
 "    pred_labels = np.argmax(test_fx.value, axis=1)\n",
 "    accuracy = (pred_labels == test_labels).mean().item()\n",
 "    accuracies.append(accuracy)"
src/micrograd_pp/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -1,4 +1,4 @@
-from ._expr import Constant, Expr, Parameter, maximum, relu
+from ._expr import Constant, Expr, Parameter, is_grad_enabled, maximum, no_grad, relu
 from ._nn import Linear, ReLU, Sequential
 from ._opt import SGD
 
Expand All @@ -13,6 +13,8 @@
"Sequential",
"SGD",
"datasets",
"is_grad_enabled",
"maximum",
"no_grad",
"relu",
)
src/micrograd_pp/_expr.py (33 changes: 28 additions & 5 deletions)
@@ -1,15 +1,34 @@
 from __future__ import annotations
 
+import contextlib
 import itertools
 from abc import ABC, abstractmethod
 from collections import deque
 from functools import lru_cache
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Generator, Sequence
 
 import numpy as np
 import numpy.typing as npt
 
 
+_grad_mode = True
+
+
+@contextlib.contextmanager
+def no_grad() -> Generator[None, None, None]:
+    """Context manager that disables gradient computation."""
+    global _grad_mode
+    state = _grad_mode
+    _grad_mode = False
+    yield
+    _grad_mode = state
+
+
+def is_grad_enabled() -> bool:
+    """Determines whether or not gradient mode is enabled."""
+    return _grad_mode
+
+
 class Expr:
     """Represents a differentiable expression in the graph.
 
@@ -34,11 +53,15 @@ def __init__(
         requires_grad: bool | None = None,
     ) -> None:
         self._value = value
-        self._children = tuple(children)
         self._label = label
-        if requires_grad is None:
-            requires_grad = any(child._requires_grad for child in children)
-        self._requires_grad = requires_grad
+        if is_grad_enabled():
+            self._children = tuple(children)
+            if requires_grad is None:
+                requires_grad = any(child._requires_grad for child in children)
+            self._requires_grad = requires_grad
+        else:
+            self._children = ()
+            self._requires_grad = False
         self._grad = None
 
     def __repr__(self) -> str:
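A minimal usage sketch of the new API (not part of this commit; the variable names are illustrative, and micrograd_pp is assumed importable as mpp, as in the tests below):

import numpy as np

import micrograd_pp as mpp

x_ = mpp.Parameter(np.random.randn(3, 2))
assert mpp.is_grad_enabled()
with mpp.no_grad():
    # Inside the context, gradient mode is off: expressions record no
    # children and are created with requires_grad=False.
    assert not mpp.is_grad_enabled()
    y_ = x_ * x_
# The previous mode is restored on exit.
assert mpp.is_grad_enabled()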
tests/test_expr.py (11 changes: 11 additions & 0 deletions)
@@ -136,6 +136,17 @@ def test_mult_scalar() -> None:
     np.testing.assert_equal(a_.grad, grad)
 
 
+def test_no_grad() -> None:
+    with mpp.no_grad():
+        a = np.random.randn(4, 1, 2)
+        b = np.random.randn(3, 2)
+        a_ = mpp.Parameter(a)
+        b_ = mpp.Parameter(b)
+        c_ = a_ * b_
+    with pytest.raises(ValueError):
+        c_.backward()
+
+
 def test_pow() -> None:
     a = np.random.randn(3, 2)
     a_ = mpp.Parameter(a)
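Since no_grad is built with contextlib.contextmanager, the object it returns also works as a decorator (contextlib's generator-based context managers inherit from ContextDecorator), which may be what the commit title refers to. A hedged sketch; evaluate and model are hypothetical names, not part of this commit:

@mpp.no_grad()
def evaluate(model, x_):
    # Gradient mode is disabled for the duration of each call.
    return model(x_)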
tests/test_mnist.py (3 changes: 2 additions & 1 deletion)
@@ -60,7 +60,8 @@ def test_mnist(batch_sz: int = 64, n_epochs: int = 3):
         loss.backward(opt=opt)
         opt.step()
         test_x = mpp.Constant(test_images)
-        test_fx = model(test_x)
+        with mpp.no_grad():
+            test_fx = model(test_x)
         pred_labels = np.argmax(test_fx.value, axis=1)
         accuracy = (pred_labels == test_labels).mean().item()
         print(f"Test accuracy at epoch {epoch}: {accuracy * 100:.2f}%")
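One more sketch (illustrative, not from the commit): because no_grad saves the prior mode and restores it on exit rather than unconditionally re-enabling gradients, nested contexts compose correctly:

with mpp.no_grad():
    with mpp.no_grad():
        assert not mpp.is_grad_enabled()
    # The inner exit restores the saved state, which was already False.
    assert not mpp.is_grad_enabled()
assert mpp.is_grad_enabled()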
