diff --git a/examples/mnist.ipynb b/examples/mnist.ipynb
index 669a3e1..bfda893 100644
--- a/examples/mnist.ipynb
+++ b/examples/mnist.ipynb
@@ -188,7 +188,8 @@
     "        loss.backward(opt=opt)\n",
     "        opt.step()\n",
     "    test_x = mpp.Constant(test_images_)\n",
-    "    test_fx = model(test_x)\n",
+    "    with mpp.no_grad():\n",
+    "        test_fx = model(test_x)\n",
     "    pred_labels = np.argmax(test_fx.value, axis=1)\n",
     "    accuracy = (pred_labels == test_labels).mean().item()\n",
     "    accuracies.append(accuracy)"
diff --git a/src/micrograd_pp/__init__.py b/src/micrograd_pp/__init__.py
index 48b65aa..3f54508 100644
--- a/src/micrograd_pp/__init__.py
+++ b/src/micrograd_pp/__init__.py
@@ -1,4 +1,4 @@
-from ._expr import Constant, Expr, Parameter, maximum, relu
+from ._expr import Constant, Expr, Parameter, is_grad_enabled, maximum, no_grad, relu
 from ._nn import Linear, ReLU, Sequential
 from ._opt import SGD
 
@@ -13,6 +13,8 @@
     "Sequential",
     "SGD",
     "datasets",
+    "is_grad_enabled",
     "maximum",
+    "no_grad",
     "relu",
 )
diff --git a/src/micrograd_pp/_expr.py b/src/micrograd_pp/_expr.py
index 984b229..68e1304 100644
--- a/src/micrograd_pp/_expr.py
+++ b/src/micrograd_pp/_expr.py
@@ -1,15 +1,37 @@
 from __future__ import annotations
 
+import contextlib
 import itertools
 from abc import ABC, abstractmethod
 from collections import deque
 from functools import lru_cache
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Generator, Sequence
 
 import numpy as np
 import numpy.typing as npt
 
 
+_grad_mode = True
+
+
+@contextlib.contextmanager
+def no_grad() -> Generator[None, None, None]:
+    """Context manager that disables gradient computation."""
+    global _grad_mode
+    state = _grad_mode
+    _grad_mode = False
+    try:
+        yield
+    finally:
+        # Restore the previous mode even if the wrapped block raises.
+        _grad_mode = state
+
+
+def is_grad_enabled() -> bool:
+    """Determines whether or not gradient mode is enabled."""
+    return _grad_mode
+
+
 class Expr:
     """Represents a differentiable expression in the graph.
 
@@ -34,11 +56,15 @@ def __init__(
         self,
         value: npt.NDArray,
         children: Sequence[Expr] = (),
         label: str | None = None,
         requires_grad: bool | None = None,
     ) -> None:
         self._value = value
-        self._children = tuple(children)
         self._label = label
-        if requires_grad is None:
-            requires_grad = any(child._requires_grad for child in children)
-        self._requires_grad = requires_grad
+        if is_grad_enabled():
+            self._children = tuple(children)
+            if requires_grad is None:
+                requires_grad = any(child._requires_grad for child in children)
+            self._requires_grad = requires_grad
+        else:
+            self._children = ()
+            self._requires_grad = False
         self._grad = None
 
     def __repr__(self) -> str:
diff --git a/tests/test_expr.py b/tests/test_expr.py
index 539915c..462e4f4 100644
--- a/tests/test_expr.py
+++ b/tests/test_expr.py
@@ -136,6 +136,17 @@ def test_mult_scalar() -> None:
     np.testing.assert_equal(a_.grad, grad)
 
 
+def test_no_grad() -> None:
+    with mpp.no_grad():
+        a = np.random.randn(4, 1, 2)
+        b = np.random.randn(3, 2)
+        a_ = mpp.Parameter(a)
+        b_ = mpp.Parameter(b)
+        c_ = a_ * b_
+    with pytest.raises(ValueError):
+        c_.backward()
+
+
 def test_pow() -> None:
     a = np.random.randn(3, 2)
     a_ = mpp.Parameter(a)
diff --git a/tests/test_mnist.py b/tests/test_mnist.py
index 4db867f..58315df 100644
--- a/tests/test_mnist.py
+++ b/tests/test_mnist.py
@@ -60,7 +60,8 @@ def test_mnist(batch_sz: int = 64, n_epochs: int = 3):
             loss.backward(opt=opt)
             opt.step()
         test_x = mpp.Constant(test_images)
-        test_fx = model(test_x)
+        with mpp.no_grad():
+            test_fx = model(test_x)
         pred_labels = np.argmax(test_fx.value, axis=1)
         accuracy = (pred_labels == test_labels).mean().item()
         print(f"Test accuracy at epoch {epoch}: {accuracy * 100:.2f}%")
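For context, a minimal usage sketch of the new API (not part of the patch): it assumes `micrograd_pp` is importable as `mpp`, as in the tests above, and borrows the array shapes and the element-wise multiply from `test_no_grad`, since those are the only operations the patch itself exercises.

```python
import numpy as np

import micrograd_pp as mpp

w_ = mpp.Parameter(np.random.randn(3, 2))    # trainable leaf
x_ = mpp.Constant(np.random.randn(4, 1, 2))  # fixed input

assert mpp.is_grad_enabled()  # gradient tracking is on by default

with mpp.no_grad():
    # Expressions built here record no children and never require gradients,
    # so evaluation does not grow the autograd graph.
    y_ = x_ * w_
    assert not mpp.is_grad_enabled()
    print(y_.value.shape)  # plain NumPy result: (4, 3, 2)

assert mpp.is_grad_enabled()  # the previous mode is restored on exit
```

Calling `backward()` on an expression created under `no_grad()` raises `ValueError`, which is exactly what `test_no_grad` checks.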