From d29dc3fc828d924cafad8e42461d07e9c9e789cc Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Thu, 8 Aug 2024 09:54:05 -0400 Subject: [PATCH] delete perturbed --- CHANGELOG.md | 8 + src/emmy/abstract/function.cljc | 13 +- src/emmy/autodiff.cljc | 401 ++++++++++++++++++++++++++++++ src/emmy/calculus/derivative.cljc | 33 +-- src/emmy/collection.cljc | 6 +- src/emmy/dual.cljc | 130 ++++++---- src/emmy/env.cljc | 6 +- src/emmy/matrix.cljc | 1 - src/emmy/operator.cljc | 1 - src/emmy/polynomial.cljc | 4 - src/emmy/quaternion.cljc | 6 - src/emmy/series.cljc | 2 - src/emmy/structure.cljc | 1 - src/emmy/tape.cljc | 397 +---------------------------- test/emmy/collection_test.cljc | 14 -- test/emmy/dual_test.cljc | 82 +++--- test/emmy/polynomial_test.cljc | 5 - test/emmy/tape_test.cljc | 21 +- 18 files changed, 559 insertions(+), 572 deletions(-) create mode 100644 src/emmy/autodiff.cljc diff --git a/CHANGELOG.md b/CHANGELOG.md index 7281d034..68a13a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,14 @@ ## [unreleased] +- #182: + + - moves the generic implementations for `TapeCell` and `Dual` to `emmy.autodiff` + + - moves `emmy.calculus.derivative/derivative` to `emmy.dual/derivative` + + - removes `emmy.dual/perturbed?` from `IPerturbed`, as this is no longer used. + - #180 renames `emmy.differential` to `emmy.dual`, since the file now contains a proper dual number implementation, not a truncated multivariate power series. 
diff --git a/src/emmy/abstract/function.cljc b/src/emmy/abstract/function.cljc index 283ed972..26a65b2f 100644 --- a/src/emmy/abstract/function.cljc +++ b/src/emmy/abstract/function.cljc @@ -12,7 +12,8 @@ (:refer-clojure :exclude [name]) (:require #?(:clj [clojure.pprint :as pprint]) [emmy.abstract.number :as an] - [emmy.dual :as d] + [emmy.autodiff :as ad] + [emmy.dual :as dual] [emmy.function :as f] [emmy.generic :as g] [emmy.matrix :as m] @@ -268,9 +269,9 @@ [f primal-s tag] (fn ([] 0) - ([tangent] (d/bundle-element (apply f primal-s) tangent tag)) + ([tangent] (dual/bundle-element (apply f primal-s) tangent tag)) ([tangent [x path _]] - (let [dx (d/tangent x tag)] + (let [dx (dual/tangent x tag)] (if (g/numeric-zero? dx) tangent (let [partial (literal-partial f path)] @@ -318,9 +319,9 @@ that input." [f s tag dx] (let [fold-fn (cond (tape/tape? dx) reverse-mode-fold - (d/dual? dx) forward-mode-fold + (dual/dual? dx) forward-mode-fold :else (u/illegal "No tape or differential inputs.")) - primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)] + primal-s (s/mapr (fn [x] (ad/primal-of x tag)) s)] (s/fold-chain (fold-fn f primal-s tag) s))) (defn- check-argument-type @@ -355,7 +356,7 @@ (if-let [[tag dx] (s/fold-chain (fn ([] []) - ([acc] (apply tape/tag+perturbation acc)) + ([acc] (apply ad/tag+perturbation acc)) ([acc [d]] (conj acc d))) s)] (literal-derivative f s tag dx) diff --git a/src/emmy/autodiff.cljc b/src/emmy/autodiff.cljc new file mode 100644 index 00000000..af746104 --- /dev/null +++ b/src/emmy/autodiff.cljc @@ -0,0 +1,401 @@ +^#:nextjournal.clerk +{:toc true + :visibility :hide-ns} +(ns emmy.autodiff + (:require [emmy.dual :as d] + [emmy.generic :as g] + [emmy.tape :as t] + [emmy.util :as u] + [emmy.value :as v])) + +;; ## Implementations + +(defn tag-of + "More permissive version of [[emmy.tape/tape-tag]] that returns `nil` when + passed a non-perturbation." + [x] + (cond (t/tape? x) (t/tape-tag x) + (d/dual? 
x) (d/tag x) + :else nil)) + +(defn primal-of + "More permissive version of [[emmy.tape/tape-primal]] that returns `v` when passed a + non-perturbation." + ([v] + (primal-of v (tag-of v))) + ([v tag] + (cond (t/tape? v) (t/tape-primal v tag) + (d/dual? v) (d/primal v tag) + :else v))) + +(defn inner-tag + "Given any number of `tags`, returns the tag most recently bound + via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call + to [[with-active-tag]]). + + If none of the tags are bound, returns `(apply max tags)`." + [& tags] + (or (some (apply hash-set tags) + d/*active-tags*) + (apply max tags))) + +(defn tag+perturbation + "Given any number of `dxs`, returns a pair of the form + + [ ] + + containing the tag and instance of [[emmy.dual/Dual]] or [[TapeCell]] + associated with the inner-most call to [[with-active-tag]] in the current call + stack. + + If none of `dxs` has an active tag, returns `nil`." + ([& dxs] + (let [xform (map + (fn [dx] + (when-let [t (tag-of dx)] + [t dx]))) + m (into {} xform dxs)] + (when (seq m) + (let [tag (apply inner-tag (keys m))] + [tag (m tag)]))))) + +(defn deep-primal + "Version of [[tape-primal]] that will descend recursively into any perturbation + instance returned by [[tape-primal]] or [[emmy.dual/primal]] until + encountering a non-perturbation. + + Given a non-perturbation, acts as identity." + ([v] + (cond (t/tape? v) (recur (t/tape-primal v)) + (d/dual? v) (recur (d/primal v)) + :else v))) + +;; ## Lifted Functions + +;; [[lift-1]] and [[lift-2]] "lift", or augment, unary or binary functions with +;; the ability to handle [[emmy.dual/Dual]] and [[TapeCell]] instances in +;; addition to whatever other types they previously supported. +;; +;; Forward-mode support for [[emmy.dual/Dual]] is an implementation of the +;; single and multivariable Taylor series expansion methods discussed at the +;; beginning of [[emmy.dual]]. 
+;; +;; To support reverse-mode automatic differentiation, When a unary or binary +;; function `f` encounters a [[TapeCell]] `x` (and `y` in the binary case) it +;; needs to return a new [[TapeCell]] with: +;; +;; - the same tag +;; - a fresh, unique ID +;; - a primal value `(f x)` (or `(f x y)`) + +;; - a map of each input to the partial of `f` with respect to that input. +;; So, `{x ((D f) x)}` in the unary case, and +;; +;; ```clojure +;; {x (((partial 0) f) x y) +;; y (((partial 1) f) x y)} +;; ```` +;; +;; in the binary case. + +;; There is a subtlety here, noted in the docstrings below. [[lift-1]] +;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that +;; can't accept [[emmy.dual/Dual]] and [[TapeCell]]s. But the first-order +;; derivatives that you have to supply _do_ have to be able to take instances of +;; these types. +;; +;; This is because, for example, the [[emmy.dual/tangent]] of [[emmy.dual/Dual]] +;; might still be an [[emmy.dual/Dual]], and will hit the first-order derivative +;; via the chain rule. +;; +;; Magically this will all Just Work if you pass an already-lifted function, or +;; a function built out of already-lifted components, as `df:dx` or `df:dy`. + +(defn lift-1 + "Given: + + - some unary function `f` + - a function `df:dx` that computes the derivative of `f` with respect to its + single argument + + Returns a new unary function that operates on both the original type of + `f`, [[TapeCell]] and [[emmy.dual/Dual]] instances. + + If called without `df:dx`, `df:dx` defaults to `(f :dfdx)`; this will return + the derivative registered to a generic function defined + with [[emmy.util.def/defgeneric]]. + + NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] + and [[emmy.dual/Dual]] instances. The best way to accomplish this is by + building `df:dx` out of already-lifted functions, and declaring them by + forward reference if you need to." 
+ ([f] + (if-let [df:dx (f :dfdx)] + (lift-1 f df:dx) + (u/illegal + "No df:dx supplied for `f` or registered generically."))) + ([f df:dx] + (fn call [x] + (cond (t/tape? x) + (let [primal (t/tape-primal x)] + (t/make (t/tape-tag x) + (call primal) + [[x (df:dx primal)]])) + + (d/dual? x) + (let [[px tx] (d/primal-tangent-pair x) + primal (call px) + tangent (g/* (df:dx px) tx)] + (d/bundle-element primal tangent (d/tag x))) + + :else (f x))))) + +(defn lift-2 + "Given: + + - some binary function `f` + - a function `df:dx` that computes the derivative of `f` with respect to its + first argument + - a function `df:dy`, similar to `df:dx` for the second arg + + Returns a new binary function that operates on both the original type of + `f`, [[TapeCell]] and [[emmy.dual/Dual]] instances. + + NOTE: `df:dx` and `df:dy` have to ALREADY be able to handle [[TapeCell]] + and [[emmy.dual/Dual]] instances. The best way to accomplish this is + by building `df:dx` and `df:dy` out of already-lifted functions, and declaring + them by forward reference if you need to." + ([f] + (let [df:dx (f :dfdx) + df:dy (f :dfdy)] + (if (and df:dx df:dy) + (lift-2 f df:dx df:dy) + (u/illegal + "No df:dx, df:dy supplied for `f` or registered generically.")))) + ([f df:dx df:dy] + (fn call [x y] + (letfn [(operate-forward [tag] + (let [[xe dx] (d/primal-tangent-pair x tag) + [ye dy] (d/primal-tangent-pair y tag) + primal (call xe ye) + tangent (g/+ (if (g/numeric-zero? dx) + dx + (g/* (df:dx xe ye) dx)) + (if (g/numeric-zero? dy) + dy + (g/* (df:dy xe ye) dy)))] + (d/bundle-element primal tangent tag))) + + (operate-reverse [tag] + (let [primal-x (t/tape-primal x tag) + primal-y (t/tape-primal y tag) + partial-x (if (and (t/tape? x) (= tag (t/tape-tag x))) + [[x (df:dx primal-x primal-y)]] + []) + partial-y (if (and (t/tape? 
y) (= tag (t/tape-tag y))) + [[y (df:dy primal-x primal-y)]] + [])] + + (t/make tag + (call primal-x primal-y) + (into partial-x partial-y))))] + (if-let [[tag dx] (tag+perturbation x y)] + (cond (t/tape? dx) (operate-reverse tag) + (d/dual? dx) (operate-forward tag) + :else + (u/illegal "Non-tape or dual perturbation!")) + (f x y)))))) + +(defn lift-n + "Given: + + - some function `f` that can handle 0, 1 or 2 arguments + - `df:dx`, a fn that returns the derivative wrt the single arg in the unary case + - `df:dx1` and `df:dx2`, fns that return the derivative with respect to the + first and second args in the binary case + + Returns a new any-arity function that operates on both the original type of + `f`, [[TapeCell]] and [[emmy.dual/Dual]] instances. + + NOTE: The n-ary case of `f` is populated by nested calls to the binary case. + That means that this is NOT an appropriate lifting method for an n-ary + function that isn't built out of associative binary calls. If you need this + ability, please file an issue at the [emmy issue + tracker](https://github.com/mentat-collective/emmy/issues)." + [f df:dx df:dx1 df:dx2] + (let [f1 (lift-1 f df:dx) + f2 (lift-2 f df:dx1 df:dx2)] + (fn call + ([] (f)) + ([x] (f1 x)) + ([x y] (f2 x y)) + ([x y & more] + (reduce call (call x y) more))))) + +;; ## Generic Method Installation +;; +;; Armed with [[lift-1]] and [[lift-2]], we can install [[TapeCell]] into +;; the Emmy generic arithmetic system. + +(defn- defunary + "Given: + + - a generic unary multimethod `generic-op` + - optionally, a corresponding single-arity lifted function + `differential-op` (defaults to `(lift-1 generic-op)`) + + installs an appropriate unary implementation of `generic-op` for `::tape` and + `:emmy.dual/dual` instances." 
+ ([generic-op] + (defunary generic-op (lift-1 generic-op))) + ([generic-op differential-op] + (defmethod generic-op [::d/dual] [a] (differential-op a)) + (defmethod generic-op [::t/tape] [a] (differential-op a)))) + +(defn- defbinary + "Given: + + - a generic binary multimethod `generic-op` + - optionally, a corresponding 2-arity lifted function + `differential-op` (defaults to `(lift-2 generic-op)`) + + installs an appropriate binary implementation of `generic-op` between + `::t/tape`, `:emmy.dual/dual` and `:emmy.value/scalar` instances." + ([generic-op] + (defbinary generic-op (lift-2 generic-op))) + ([generic-op differential-op] + (doseq [signature [[::t/tape ::t/tape] + [::d/dual ::d/dual] + [::t/tape ::d/dual] + [::d/dual ::t/tape] + [::v/scalar ::t/tape] + [::v/scalar ::d/dual] + [::t/tape ::v/scalar] + [::d/dual ::v/scalar]]] + (defmethod generic-op signature [a b] (differential-op a b))))) + +(defn ^:no-doc by-primal + "Given some unary or binary function `f`, returns an augmented `f` that acts on + the primal entries of any perturbed arguments encountered, irrespective of + tag." + [f] + (fn + ([x] (f (deep-primal x))) + ([x y] (f (deep-primal x) + (deep-primal y))))) + +(defbinary g/add) +(defunary g/negate) +(defbinary g/sub) + +(let [mul (lift-2 g/mul)] + (defbinary g/mul mul) + (defbinary g/dot-product mul)) +(defbinary g/expt) + +(defunary g/square) +(defunary g/cube) + +(defunary g/invert) +(defbinary g/div) + +(defunary g/abs + (fn [x] + (let [f (deep-primal x) + func (cond (< f 0) (lift-1 g/negate (fn [_] -1)) + (> f 0) (lift-1 identity (fn [_] 1)) + (= f 0) (u/illegal "Derivative of g/abs undefined at zero") + :else (u/illegal (str "error! derivative of g/abs at" x)))] + (func x)))) + +(defn- discont-at-integers [f dfdx] + (let [f (lift-1 f (fn [_] dfdx)) + f-name (g/freeze f)] + (fn [x] + (if (v/integral? 
(deep-primal x)) + (u/illegal + (str "Derivative of g/" f-name " undefined at integral points.")) + (f x))))) + +(defunary g/floor + (discont-at-integers g/floor 0)) + +(defunary g/ceiling + (discont-at-integers g/ceiling 0)) + +(defunary g/integer-part + (discont-at-integers g/integer-part 0)) + +(defunary g/fractional-part + (discont-at-integers g/fractional-part 1)) + +(let [div (lift-2 g/div)] + (defbinary g/solve-linear (fn [l r] (div r l))) + (defbinary g/solve-linear-right div)) + +(defunary g/sqrt) +(defunary g/log) +(defunary g/exp) + +(defunary g/cos) +(defunary g/sin) +(defunary g/tan) +(defunary g/cot) +(defunary g/sec) +(defunary g/csc) + +(defunary g/atan) +(defbinary g/atan) +(defunary g/asin) +(defunary g/acos) +(defunary g/acot) +(defunary g/asec) +(defunary g/acsc) + +(defunary g/cosh) +(defunary g/sinh) +(defunary g/tanh) +(defunary g/sech) +(defunary g/coth) +(defunary g/csch) + +(defunary g/acosh) +(defunary g/asinh) +(defunary g/atanh) +(defunary g/acoth) +(defunary g/asech) +(defunary g/acsch) + +(defunary g/sinc) +(defunary g/sinhc) +(defunary g/tanc) +(defunary g/tanhc) + +;; Non-differentiable generic operations + +(defbinary v/= (by-primal v/=)) +(defunary g/zero? + (let [zero-p? (by-primal g/zero?)] + (fn [dx] + (if (t/tape? dx) + (zero-p? dx) + (let [[p t] (d/primal-tangent-pair dx)] + (and (g/zero? p) + (g/zero? t))))))) + +(defunary g/one? + (let [one-p? (by-primal g/one?)] + (fn [dx] + (if (t/tape? dx) + (one-p? dx) + (d/one? dx))))) + +(defunary g/identity? + (let [identity-p? (by-primal g/identity?)] + (fn [dx] + (if (t/tape? dx) + (identity-p? dx) + (d/identity? dx))))) + +(defunary g/negative? (by-primal g/negative?)) +(defunary g/infinite? 
(by-primal g/infinite?)) diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index 334854a1..53c1a742 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -7,7 +7,8 @@ "This namespace implements a number of differential operators like [[D]], and the machinery to apply [[D]] to various structures." (:refer-clojure :exclude [partial]) - (:require [emmy.dual :as d] + (:require [emmy.autodiff] + [emmy.dual :as d] [emmy.expression :as x] [emmy.function :as f] [emmy.generic :as g] @@ -25,26 +26,6 @@ ;; in [[emmy.dual]] and declare an interface for taking ;; derivatives. -(defn derivative - "Returns a single-argument function of that, when called with an argument `x`, - returns the derivative of `f` at `x` using forward-mode automatic - differentiation. - - For numerical differentiation, - see [[emmy.numerical.derivative/D-numeric]]. - - `f` must be built out of generic operations that know how to - handle [[emmy.dual/Dual]] inputs in addition to any types that - a normal `(f x)` call would present. This restriction does _not_ apply to - operations like putting `x` into a container or destructuring; just primitive - function calls." - [f] - (fn [x] - (let [tag (d/fresh-tag) - lifted (d/bundle-element x 1 tag)] - (-> (d/with-active-tag tag f [lifted]) - (d/extract-tangent tag d/FORWARD-MODE))))) - ;; The result of applying the derivative `(D f)` of a multivariable function `f` ;; to a sequence of `args` is a structure of the same shape as `args` with all ;; orientations flipped. 
(For a partial derivative like `((partial 0 1) f)` the @@ -55,9 +36,9 @@ ;; ;; To generate the result: ;; -;; - For a single non-structural argument, return `(derivative f)` +;; - For a single non-structural argument, return `(d/derivative f)` ;; - else, bundle up all arguments into a single [[s/Structure]] instance `xs` -;; - Generate `xs'` by replacing each entry in `xs` with `((derivative f') +;; - Generate `xs'` by replacing each entry in `xs` with `((d/derivative f') ;; entry)`, where `f'` is a function of ONLY that entry that ;; calls `(f (assoc-in xs path entry))`. In other words, replace each entry ;; with the result of the partial derivative of `f` at only that entry. @@ -83,7 +64,7 @@ (if (v/scalar? entry) (letfn [(f-entry [x] (f (assoc-in structure path x)))] - ((derivative f-entry) entry)) + ((d/derivative f-entry) entry)) (u/illegal (str "non-numerical entry " entry " at path " path @@ -143,9 +124,9 @@ ;; non-empty selectors are only allowed for functions that receive ;; a structural argument. This case passes that single, - ;; non-structural argument on to `(derivative f)`. + ;; non-structural argument on to `(d/derivative f)`. (empty? selectors) - ((derivative f) input) + ((d/derivative f) input) ;; Any attempt to index (via non-empty selectors) into a ;; non-structural argument will throw. diff --git a/src/emmy/collection.cljc b/src/emmy/collection.cljc index 85698e2c..f3b6db1a 100644 --- a/src/emmy/collection.cljc +++ b/src/emmy/collection.cljc @@ -57,7 +57,6 @@ ;; perturbed. [[d/replace-tag]] and [[d/extract-tangent]] pass the buck down ;; the vector's elements. d/IPerturbed - (perturbed? [v] (boolean (some d/perturbed? v))) (replace-tag [v old new] (mapv #(d/replace-tag % old new) v)) (extract-tangent [v tag mode] (mapv #(d/extract-tangent % tag mode) v))) @@ -84,7 +83,6 @@ (kind [xs] (type xs)) d/IPerturbed - (perturbed? 
[_] false) (replace-tag [xs old new] (map #(d/replace-tag % old new) xs)) (extract-tangent [xs tag mode] (map #(d/extract-tangent % tag mode) xs)))) @@ -178,8 +176,7 @@ {:arity (fn [_] [:between 1 2])} d/IPerturbed - {:perturbed? (fn [m] (boolean (some d/perturbed? (vals m)))) - :replace-tag (fn [m old new] (u/map-vals #(d/replace-tag % old new) m)) + {:replace-tag (fn [m old new] (u/map-vals #(d/replace-tag % old new) m)) :extract-tangent (fn [m tag mode] (if-let [t (:type m)] @@ -199,7 +196,6 @@ (arity [_] [:between 1 2]) d/IPerturbed - (perturbed? [m] (boolean (some d/perturbed? (vals m)))) (replace-tag [m old new] (u/map-vals #(d/replace-tag % old new) m)) (extract-tangent [m tag mode] (if-let [t (:type m)] diff --git a/src/emmy/dual.cljc b/src/emmy/dual.cljc index 67fe8892..3c16bf56 100644 --- a/src/emmy/dual.cljc +++ b/src/emmy/dual.cljc @@ -7,19 +7,18 @@ "This namespace contains an implementation of [[Dual]], a type that forms the basis for the forward-mode automatic differentiation implementation in emmy. - See [[emmy.calculus.derivative]] for a fleshed-out derivative - implementation using [[Dual]]." + See [[emmy.calculus.derivative]] for a fleshed-out derivative implementation + using [[Dual]]." (:refer-clojure :exclude [compare]) (:require [emmy.generic :as g] [emmy.value :as v])) -;; ## Differentials, Dual Numbers and Automatic Differentiation +;; ## Dual Numbers and Automatic Differentiation ;; -;; This namespace develops an implementation of a type called [[Dual]]. -;; A [[Dual]] is a generalization of a type called a ["dual +;; This namespace develops an implementation of a ["dual ;; number"](https://en.wikipedia.org/wiki/Dual_number). ;; -;; As we'll discuss, passing these numbers as arguments to some function $f$ +;; As we'll discuss, passing dual numbers as arguments to some function $f$ ;; built out of the [[emmy.generic]] operators allows us to build up the ;; _derivative_ of $f$ in parallel to our evaluation of $f$. 
Complex programs ;; are built out of simple pieces that we know how to evaluate; we can build up @@ -28,17 +27,21 @@ ;; ;; ### Forward-Mode Automatic Differentiation ;; -;; For many scientific computing applications, it's valuable be able to generate -;; a "derivative" of a function; given some tiny increment in the inputs, what -;; tiny increment will the function produce in the output values? +;; For many scientific computing applications, it's valuable to be able to +;; generate a "derivative" of a function; given some tiny increment in the +;; inputs, what tiny increment will the function produce in the output values? ;; ;; we know how to take derivatives of many of the generic functions exposed by -;; Emmy, like [[+]], [[*]], [[emmy.generic/sin]] and friends. It turns out that -;; we can take the derivatives of large, complicated functions by combining the -;; derivatives of these smaller functions using the [chain +;; Emmy, like [[emmy.generic/+]], [[emmy.generic/*]], [[emmy.generic/sin]] and +;; friends. It turns out that we can take the derivatives of large, complicated +;; functions by combining the derivatives of these smaller functions using +;; the [chain ;; rule]((https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule,_forward_and_reverse_accumulation)) ;; as a clever bookkeeping device. ;; +;; NOTE the two flavors are forward and reverse mode. First we'll do forward, +;; then reverse. +;; ;; The technique of evaluating a function and its derivative in parallel is ;; called "forward-mode [Automatic ;; Differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)". @@ -50,8 +53,6 @@ ;; page](https://cljdoc.org/d/org.mentat/emmy/CURRENT/doc/calculus/automatic-differentiation) ;; for "how do I use this?"-style questions. ;; -;; > NOTE: The other flavor of automatic differentiation (AD) is "reverse-mode -;; > AD". See [[emmy.tape]] for an implementation of this style, coming soon! 
;; ;; ### Dual Numbers and AD ;; @@ -205,8 +206,8 @@ (derivative (derivative f))) -;; But this guess hits one of many subtle problems with the implementation of -;; forward-mode AD. The double-call to `derivative` will expand out to this: +;; But this guess hits a subtle problem with the implementation of forward-mode +;; AD. The double-call to `derivative` will expand out like so: (comment (fn [x] @@ -236,9 +237,10 @@ ;; ;; The solution is to introduce a new $\varepsilon$ for every level, and allow ;; different $\varepsilon$ instances to multiply without annihilating. Each -;; $\varepsilon$ is called a "tag". [[Dual]] (implemented below) is a -;; generalized dual number that can track many tags at once, allowing nested -;; derivatives like the one described above to work. +;; $\varepsilon$ is called a "tag". By allowing dual numbers to contain dual +;; numbers with different tags in the primal and tangent slots, and carefully +;; managing which tag stays at the top level, we can allow nested derivatives +;; like the one described above to work. ;; ;; This implies that `extract-tangent` needs to take a tag, to determine _which_ ;; tangent to extract: @@ -251,14 +253,14 @@ (extract-tangent tag)))))) ;; This is close to the final form you'll find -;; at [[emmy.calculus.derivative/derivative]]. +;; at [[derivative]]. ;; ;; ### What Return Values are Allowed? ;; ;; Before we discuss the implementation of dual -;; numbers (called [[Dual]]), [[emmy.tape/lift-1]], [[emmy.tape/lift-2]] -;; and the rest of the machinery that makes this all possible; what sorts of -;; objects is `f` allowed to return? +;; numbers, [[emmy.autodiff/lift-1]], [[emmy.autodiff/lift-2]] and the rest of the +;; machinery that makes this all possible; what sorts of objects is `f` allowed +;; to return? ;; ;; The dual number approach is beautiful because we can bring to bear all sorts ;; of operations in Clojure that never even _see_ dual numbers. 
For example, @@ -290,9 +292,9 @@ (g (+ x offset)))))) ;; `(derivative offset-fn)` here returns a function! Manzyuk et al. 2019 makes -;; the reasonable claim that, if `(f x)` returns a function, then `(derivative -;; f)` should treat `f` as a multi-argument function with its first argument -;; curried. +;; the reasonable claim that, if `(offset-fn x)` returns a function, +;; then `(derivative offset-fn)` should treat `offset-fn` as a multi-argument +;; function with its first argument curried. ;; ;; Let's say `f` takes a number `x` and returns a function `g` that maps number ;; => number. `(((derivative f) x) y)` should act just like the partial @@ -326,11 +328,6 @@ ;; that we'll use later: (defprotocol IPerturbed - (perturbed? [this] - "Returns true if the supplied object has some known non-zero tangent to be - extracted via [[extract-tangent]], false otherwise. (Return `false` by - default if you can't detect a perturbation.)") - (replace-tag [this old-tag new-tag] "If `this` is perturbed, Returns a similar object with the perturbation modified by replacing any appearance of `old-tag` with `new-tag`. Else, @@ -340,28 +337,50 @@ "If `this` is perturbed, return the tangent component paired with the supplied tag. Else, returns `([[emmy.value/zero-like]] this)`.") - (extract-id [this id])) + (extract-id [this id] + "Returns the value recorded for `id` in `this`; implementations with nothing recorded return 0.")) + +(defrecord Completed [v->partial] + IPerturbed + ;; NOTE that it's a problem that `replace-tag` is called on [[Completed]] + ;; instances now. In a future refactor I want `get` calls out of + ;; a [[Completed]] map to occur before tag replacement needs to happen. + (replace-tag [_ old new] + (Completed. + (replace-tag v->partial old new))) + + ;; This should never be called; it would be that a [[Completed]] instance has + ;; escaped from a derivative call. 
+ (extract-tangent [_ _ _] (assert "Impossible!")) + (extract-id [_ id] (get v->partial id 0))) (def FORWARD-MODE ::forward) (def REVERSE-MODE ::reverse) +(def REVERSE-EMPTY (->Completed {})) ;; `replace-tag` exists to handle subtle bugs that can arise in the case of ;; functional return values. See the "Amazing Bug" sections -;; in [[emmy.calculus.derivative-test]] for detailed examples on how this -;; might bite you. +;; in [[emmy.calculus.derivative-test]] for detailed examples on how this might +;; bite you. ;; ;; The default implementations are straightforward, and match the docstrings: (extend-protocol IPerturbed nil - (perturbed? [_] false) (replace-tag [_ _ _] nil) - (extract-tangent [_ _ _] 0) + (extract-id [_ _] 0) + (extract-tangent [_ _ mode] + (if (= mode FORWARD-MODE) + 0 + REVERSE-EMPTY)) #?(:clj Object :cljs default) - (perturbed? [_] false) (replace-tag [this _ _] this) - (extract-tangent [this _ _] (g/zero-like this))) + (extract-id [_ _] 0) + (extract-tangent [this _ mode] + (if (= mode FORWARD-MODE) + (g/zero-like this) + REVERSE-EMPTY))) ;; ## Dual Implementation ;; @@ -392,15 +411,15 @@ (deftype Dual [tag primal tangent] IPerturbed - (perturbed? [_] true) - (replace-tag [this old new] (if (= old tag) (Dual. new primal tangent) this)) - (extract-tangent [_ t _] - (if (= t tag) tangent 0)) + (extract-tangent [_ t mode] + (cond (not= mode FORWARD-MODE) REVERSE-EMPTY + (= t tag) tangent + :else 0)) v/IKind (kind [_] ::dual) @@ -412,9 +431,9 @@ #?(:clj (equals [a b] (equiv a b))) #?(:cljs (valueOf [_] (.valueOf primal))) (toString [_] - (str "#emmy.tape.Dual" - {:tag tag - :primal primal + (str "#emmy.dual.Dual" + {:tag tag + :primal primal :tangent tangent})) #?@(:clj @@ -660,6 +679,27 @@ (primal a) (primal b))) +;; ## Derivative + +(defn derivative + "Returns a single-argument function that, when called with an argument `x`, + returns the derivative of `f` at `x` using forward-mode automatic + differentiation. 
+ + For numerical differentiation, + see [[emmy.numerical.derivative/D-numeric]]. + + `f` must be built out of generic operations that know how to handle [[Dual]] + inputs in addition to any types that a normal `(f x)` call would present. This + restriction does _not_ apply to operations like putting `x` into a container + or destructuring; just primitive function calls." + [f] + (fn [x] + (let [tag (fresh-tag) + lifted (bundle-element x 1 tag)] + (-> (with-active-tag tag f [lifted]) + (extract-tangent tag FORWARD-MODE))))) + ;; ## Chain Rule and Lifted Functions ;; ;; For the rest of the story, please see the implementations diff --git a/src/emmy/env.cljc b/src/emmy/env.cljc index 8a77213b..5e45c046 100644 --- a/src/emmy/env.cljc +++ b/src/emmy/env.cljc @@ -45,6 +45,7 @@ [emmy.calculus.vector-calculus] [emmy.calculus.vector-field] [emmy.complex] + [emmy.dual] [emmy.expression] [emmy.expression.render :as render] [emmy.function :as f] @@ -404,8 +405,7 @@ Riemann-curvature Riemann Ricci torsion-vector torsion curvature-components] - [emmy.calculus.derivative - derivative D D-as-matrix taylor-series] + [emmy.calculus.derivative D D-as-matrix taylor-series] [emmy.calculus.form-field form-field? nform-field? oneform-field? @@ -502,6 +502,8 @@ basis-components->vector-field vector-field->basis-components coordinatize evolution] + + [emmy.dual derivative] ;; Special Relativity diff --git a/src/emmy/matrix.cljc b/src/emmy/matrix.cljc index fcdbbb52..b5850eba 100644 --- a/src/emmy/matrix.cljc +++ b/src/emmy/matrix.cljc @@ -35,7 +35,6 @@ :else ::matrix)) d/IPerturbed - (perturbed? [_] (boolean (core/some d/perturbed? v))) (replace-tag [M old new] (fmap #(d/replace-tag % old new) M)) (extract-tangent [M tag mode] (fmap #(d/extract-tangent % tag mode) M)) diff --git a/src/emmy/operator.cljc b/src/emmy/operator.cljc index d25daa83..014d606d 100644 --- a/src/emmy/operator.cljc +++ b/src/emmy/operator.cljc @@ -33,7 +33,6 @@ (arity [_] arity) d/IPerturbed - (perturbed? 
[_] false) (replace-tag [_ old new] (Operator. (d/replace-tag o old new) arity name context m)) (extract-tangent [_ tag mode] diff --git a/src/emmy/polynomial.cljc b/src/emmy/polynomial.cljc index e4ab83f5..13930837 100644 --- a/src/emmy/polynomial.cljc +++ b/src/emmy/polynomial.cljc @@ -125,10 +125,6 @@ (arity [_] [:between 0 arity]) sd/IPerturbed - (perturbed? [_] - (let [coefs (map i/coefficient terms)] - (boolean (some sd/perturbed? coefs)))) - (replace-tag [this old new] (map-coefficients #(sd/replace-tag % old new) this)) diff --git a/src/emmy/quaternion.cljc b/src/emmy/quaternion.cljc index 17fa976b..322fa4ed 100644 --- a/src/emmy/quaternion.cljc +++ b/src/emmy/quaternion.cljc @@ -110,12 +110,6 @@ (arity [this] (arity this)) d/IPerturbed - (perturbed? [_] - (or (d/perturbed? r) - (d/perturbed? i) - (d/perturbed? j) - (d/perturbed? k))) - (replace-tag [_ old new] (Quaternion. (d/replace-tag r old new) diff --git a/src/emmy/series.cljc b/src/emmy/series.cljc index 10559332..ea04c31f 100644 --- a/src/emmy/series.cljc +++ b/src/emmy/series.cljc @@ -39,7 +39,6 @@ (arity [_] (f/arity (first xs))) d/IPerturbed - (perturbed? [_] false) (replace-tag [s old new] (fmap #(d/replace-tag % old new) s)) (extract-tangent [s tag mode] (fmap #(d/extract-tangent % tag mode) s)) @@ -196,7 +195,6 @@ (arity [_] [:exactly 1]) d/IPerturbed - (perturbed? [_] false) (replace-tag [s old new] (fmap #(d/replace-tag % old new) s)) (extract-tangent [s tag mode] (fmap #(d/extract-tangent % tag mode) s)) diff --git a/src/emmy/structure.cljc b/src/emmy/structure.cljc index 8ef21c15..fba9d26c 100644 --- a/src/emmy/structure.cljc +++ b/src/emmy/structure.cljc @@ -82,7 +82,6 @@ (f/seq-arity v)) d/IPerturbed - (perturbed? [_] (boolean (some d/perturbed? 
v))) (replace-tag [s old new] (mapr #(d/replace-tag % old new) s)) (extract-tangent [s tag mode] (mapr #(d/extract-tangent % tag mode) s)) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index 2d565a5a..b3ab6c8a 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -132,8 +132,6 @@ (kind [_] ::tape) d/IPerturbed - (perturbed? [_] true) - (replace-tag [this old new] (if (= old tag) (TapeCell. new id primal in->partial) @@ -307,66 +305,6 @@ :primal (.-primal t) :in->partial (.-in->partial t)}) -(defn tag-of - "More permissive version of [[tape-tag]] that returns `nil` when passed a - non-perturbation." - [x] - (cond (tape? x) (tape-tag x) - (d/dual? x) (d/tag x) - :else nil)) - -(defn inner-tag - "Given any number of `tags`, returns the tag most recently bound - via [[with-active-tag]] (i.e., the tag connected with the _innermost_ call - to [[with-active-tag]]). - - If none of the tags are bound, returns `(apply max tags)`." - [& tags] - (or (some (apply hash-set tags) - d/*active-tags*) - (apply max tags))) - -(defn tag+perturbation - "Given any number of `dxs`, returns a pair of the form - - [ ] - - containing the tag and instance of [[emmy.dual/Dual]] or [[TapeCell]] - associated with the inner-most call to [[with-active-tag]] in the current call - stack. - - If none of `dxs` has an active tag, returns `nil`." - ([& dxs] - (let [xform (map - (fn [dx] - (when-let [t (tag-of dx)] - [t dx]))) - m (into {} xform dxs)] - (when (seq m) - (let [tag (apply inner-tag (keys m))] - [tag (m tag)]))))) - -(defn primal-of - "More permissive version of [[tape-primal]] that returns `v` when passed a - non-perturbation." - ([v] - (primal-of v (tag-of v))) - ([v tag] - (cond (tape? v) (tape-primal v tag) - (d/dual? v) (d/primal v tag) - :else v))) - -(defn deep-primal - "Version of [[tape-primal]] that will descend recursively into any perturbation - instance returned by [[tape-primal]] or [[emmy.dual/primal]] until - encountering a non-perturbation. 
- - Given a non-perturbation, acts as identity." - ([v] - (cond (tape? v) (recur (tape-primal v)) - (d/dual? v) (recur (d/primal v)) - :else v))) - ;; ### Comparison, Control Flow ;; ;; Functions like `=`, `<` and friends don't have derivatives; instead, they're @@ -490,11 +428,6 @@ ;; escaped from a derivative call. These are meant to be an internal ;; implementation detail only. (extract-tangent [_ _ _] - (assert "Impossible!")) - - ;; This is called on arguments to literal functions to check if a derivative - ;; needs to be taken. This should never happen with a [[Completed]] instance! - (perturbed? [_] (assert "Impossible!"))) (defn process [sensitivities tape] @@ -715,335 +648,7 @@ selectors) (matrix/seq-> (cons x more))))))) -;; ## Lifted Functions - -;; [[lift-1]] and [[lift-2]] "lift", or augment, unary or binary functions with -;; the ability to handle [[emmy.dual/Dual]] and [[TapeCell]] instances -;; in addition to whatever other types they previously supported. -;; -;; Forward-mode support for [[emmy.dual/Dual]] is an implementation of -;; the single and multivariable Taylor series expansion methods discussed at the -;; beginning of [[emmy.dual]]. -;; -;; To support reverse-mode automatic differentiation, When a unary or binary -;; function `f` encounters a [[TapeCell]] `x` (and `y` in the binary case) it -;; needs to return a new [[TapeCell]] with: -;; -;; - the same tag -;; - a fresh, unique ID -;; - a primal value `(f x)` (or `(f x y)`) - -;; - a map of each input to the partial of `f` with respect to that input. -;; So, `{x ((D f) x)}` in the unary case, and -;; -;; ```clojure -;; {x (((partial 0) f) x y) -;; y (((partial 1) f) x y)} -;; ```` -;; -;; in the binary case. - -;; There is a subtlety here, noted in the docstrings below. [[lift-1]] -;; and [[lift-2]] really are able to lift functions like [[clojure.core/+]] that -;; can't accept [[emmy.dual/Dual]] and [[TapeCell]]s. 
But the -;; first-order derivatives that you have to supply _do_ have to be able to take -;; instances of these types. -;; -;; This is because, for example, the [[emmy.dual/tangent]] of [[emmy.dual/Dual]] -;; might still be a [[emmy.dual/Dual]], and will hit the first-order derivative via the -;; chain rule. -;; -;; Magically this will all Just Work if you pass an already-lifted function, or -;; a function built out of already-lifted components, as `df:dx` or `df:dy`. - -(defn lift-1 - "Given: - - - some unary function `f` - - a function `df:dx` that computes the derivative of `f` with respect to its - single argument - - Returns a new unary function that operates on both the original type of - `f`, [[TapeCell]] and [[emmy.dual/Dual]] instances. - - If called without `df:dx`, `df:dx` defaults to `(f :dfdx)`; this will return - the derivative registered to a generic function defined - with [[emmy.util.def/defgeneric]]. - - NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] - and [[emmy.dual/Dual]] instances. The best way to accomplish this is - by building `df:dx` out of already-lifted functions, and declaring them by - forward reference if you need to." - ([f] - (if-let [df:dx (f :dfdx)] - (lift-1 f df:dx) - (u/illegal - "No df:dx supplied for `f` or registered generically."))) - ([f df:dx] - (fn call [x] - (cond (tape? x) - (let [primal (tape-primal x)] - (make (tape-tag x) - (call primal) - [[x (df:dx primal)]])) - - (d/dual? x) - (let [[px tx] (d/primal-tangent-pair x) - primal (call px) - tangent (g/* (df:dx px) tx)] - (d/bundle-element primal tangent (d/tag x))) - - :else (f x))))) - -(defn lift-2 - "Given: - - - some binary function `f` - - a function `df:dx` that computes the derivative of `f` with respect to its - single argument - - a function `df:dy`, similar to `df:dx` for the second arg - - Returns a new binary function that operates on both the original type of - `f`, [[TapeCell]] and [[emmy.dual/Dual]] instances. 
- - NOTE: `df:dx` and `df:dy` have to ALREADY be able to handle [[TapeCell]] - and [[emmy.dual/Dual]] instances. The best way to accomplish this is - by building `df:dx` and `df:dy` out of already-lifted functions, and declaring - them by forward reference if you need to." - ([f] - (let [df:dx (f :dfdx) - df:dy (f :dfdy)] - (if (and df:dx df:dy) - (lift-2 f df:dx df:dy) - (u/illegal - "No df:dx, df:dy supplied for `f` or registered generically.")))) - ([f df:dx df:dy] - (fn call [x y] - (letfn [(operate-forward [tag] - (let [[xe dx] (d/primal-tangent-pair x tag) - [ye dy] (d/primal-tangent-pair y tag) - primal (call xe ye) - tangent (g/+ (if (g/numeric-zero? dx) - dx - (g/* (df:dx xe ye) dx)) - (if (g/numeric-zero? dy) - dy - (g/* (df:dy xe ye) dy)))] - (d/bundle-element primal tangent tag))) - - (operate-reverse [tag] - (let [primal-x (tape-primal x tag) - primal-y (tape-primal y tag) - partial-x (if (and (tape? x) (= tag (tape-tag x))) - [[x (df:dx primal-x primal-y)]] - []) - partial-y (if (and (tape? y) (= tag (tape-tag y))) - [[y (df:dy primal-x primal-y)]] - [])] - - (make tag - (call primal-x primal-y) - (into partial-x partial-y))))] - (if-let [[tag dx] (tag+perturbation x y)] - (cond (tape? dx) (operate-reverse tag) - (d/dual? dx) (operate-forward tag) - :else - (u/illegal "Non-tape or dual perturbation!")) - (f x y)))))) - -(defn lift-n - "Given: - - - some function `f` that can handle 0, 1 or 2 arguments - - `df:dx`, a fn that returns the derivative wrt the single arg in the unary case - - `df:dx1` and `df:dx2`, fns that return the derivative with respect to the - first and second args in the binary case - - Returns a new any-arity function that operates on both the original type of - `f`, [[TapeCell]] and [[emmy.dual/Dual]] instances. - - NOTE: The n-ary case of `f` is populated by nested calls to the binary case. - That means that this is NOT an appropriate lifting method for an n-ary - function that isn't built out of associative binary calls. 
If you need this - ability, please file an issue at the [emmy issue - tracker](https://github.com/mentat-collective/emmy/issues)." - [f df:dx df:dx1 df:dx2] - (let [f1 (lift-1 f df:dx) - f2 (lift-2 f df:dx1 df:dx2)] - (fn call - ([] (f)) - ([x] (f1 x)) - ([x y] (f2 x y)) - ([x y & more] - (reduce call (call x y) more))))) - -;; ## Generic Method Installation -;; -;; Armed with [[lift-1]] and [[lift-2]], we can install [[TapeCell]] into -;; the Emmy generic arithmetic system. - -(defn- defunary - "Given: - - - a generic unary multimethod `generic-op` - - optionally, a corresponding single-arity lifted function - `differential-op` (defaults to `(lift-1 generic-op)`) - - installs an appropriate unary implementation of `generic-op` for `::tape` and - `:emmy.dual/dual` instances." - ([generic-op] - (defunary generic-op (lift-1 generic-op))) - ([generic-op differential-op] - (defmethod generic-op [::d/dual] [a] (differential-op a)) - (defmethod generic-op [::tape] [a] (differential-op a)))) - -(defn- defbinary - "Given: - - - a generic binary multimethod `generic-op` - - optionally, a corresponding 2-arity lifted function - `differential-op` (defaults to `(lift-2 generic-op)`) - - installs an appropriate binary implementation of `generic-op` between - `::tape`, `::emmy.dual/dual` and `::v/scalar` instances." - ([generic-op] - (defbinary generic-op (lift-2 generic-op))) - ([generic-op differential-op] - (doseq [signature [[::tape ::tape] - [::d/dual ::d/dual] - [::tape ::d/dual] - [::d/dual ::tape] - [::v/scalar ::tape] - [::v/scalar ::d/dual] - [::tape ::v/scalar] - [::d/dual ::v/scalar]]] - (defmethod generic-op signature [a b] (differential-op a b))))) - -(defn ^:no-doc by-primal - "Given some unary or binary function `f`, returns an augmented `f` that acts on - the primal entries of any perturbed arguments encountered, irrespective of - tag." 
- [f] - (fn - ([x] (f (deep-primal x))) - ([x y] (f (deep-primal x) - (deep-primal y))))) - -(defbinary g/add) -(defunary g/negate) -(defbinary g/sub) - -(let [mul (lift-2 g/mul)] - (defbinary g/mul mul) - (defbinary g/dot-product mul)) -(defbinary g/expt) - -(defunary g/square) -(defunary g/cube) - -(defunary g/invert) -(defbinary g/div) - -(defunary g/abs - (fn [x] - (let [f (deep-primal x) - func (cond (< f 0) (lift-1 g/negate (fn [_] -1)) - (> f 0) (lift-1 identity (fn [_] 1)) - (= f 0) (u/illegal "Derivative of g/abs undefined at zero") - :else (u/illegal (str "error! derivative of g/abs at" x)))] - (func x)))) - -(defn- discont-at-integers [f dfdx] - (let [f (lift-1 f (fn [_] dfdx)) - f-name (g/freeze f)] - (fn [x] - (if (v/integral? (deep-primal x)) - (u/illegal - (str "Derivative of g/" f-name " undefined at integral points.")) - (f x))))) - -(defunary g/floor - (discont-at-integers g/floor 0)) - -(defunary g/ceiling - (discont-at-integers g/ceiling 0)) - -(defunary g/integer-part - (discont-at-integers g/integer-part 0)) - -(defunary g/fractional-part - (discont-at-integers g/fractional-part 1)) - -(let [div (lift-2 g/div)] - (defbinary g/solve-linear (fn [l r] (div r l))) - (defbinary g/solve-linear-right div)) - -(defunary g/sqrt) -(defunary g/log) -(defunary g/exp) - -(defunary g/cos) -(defunary g/sin) -(defunary g/tan) -(defunary g/cot) -(defunary g/sec) -(defunary g/csc) - -(defunary g/atan) -(defbinary g/atan) -(defunary g/asin) -(defunary g/acos) -(defunary g/acot) -(defunary g/asec) -(defunary g/acsc) - -(defunary g/cosh) -(defunary g/sinh) -(defunary g/tanh) -(defunary g/sech) -(defunary g/coth) -(defunary g/csch) - -(defunary g/acosh) -(defunary g/asinh) -(defunary g/atanh) -(defunary g/acoth) -(defunary g/asech) -(defunary g/acsch) - -(defunary g/sinc) -(defunary g/sinhc) -(defunary g/tanc) -(defunary g/tanhc) - -;; Non-differentiable generic operations - -(defbinary v/= (by-primal v/=)) -(defunary g/zero? - (let [zero-p? 
(by-primal g/zero?)] - (fn [dx] - (if (tape? dx) - (zero-p? dx) - (let [[p t] (d/primal-tangent-pair dx)] - (and (g/zero? p) - (g/zero? t))))))) - -(defunary g/one? - (let [one-p? (by-primal g/one?)] - (fn [dx] - (if (tape? dx) - (one-p? dx) - (d/one? dx))))) - -(defunary g/identity? - (let [identity-p? (by-primal g/identity?)] - (fn [dx] - (if (tape? dx) - (identity-p? dx) - (d/identity? dx))))) - -(defunary g/negative? (by-primal g/negative?)) -(defunary g/infinite? (by-primal g/infinite?)) +;; ## TapeCell Generics (defmethod g/zero-like [::tape] [_] 0) (defmethod g/one-like [::tape] [_] 1) diff --git a/test/emmy/collection_test.cljc b/test/emmy/collection_test.cljc index febcfdd5..76bae6c7 100644 --- a/test/emmy/collection_test.cljc +++ b/test/emmy/collection_test.cljc @@ -254,20 +254,6 @@ :two-prime (complex 2 0)})) " with unequal keys, we fail."))) - (checking "d/perturbed?" 100 - [m (gen/map gen/keyword sg/any-integral)] - (is (not (d/perturbed? m)) - "maps with no [[Differential]] aren't perturbed.") - - (let [diff (d/bundle-element 1 1 0)] - (is (d/perturbed? (assoc m :key diff)) - "adding a perturbed entry perturbs the map.") - - (is (d/perturbed? - {:outer-key - (assoc m :key diff)}) - "d/perturbed? descends into keys"))) - (let [m {:sin g/sin :cos g/cos} {D-sin :sin D-cos :cos} (D m)] (is (= {:sin ((D g/sin) 'x) diff --git a/test/emmy/dual_test.cljc b/test/emmy/dual_test.cljc index 34b5d86f..ed7d63cb 100644 --- a/test/emmy/dual_test.cljc +++ b/test/emmy/dual_test.cljc @@ -5,27 +5,17 @@ [clojure.test :refer [is deftest testing use-fixtures]] [clojure.test.check.generators :as gen] [com.gfredericks.test.chuck.clojure-test :refer [checking]] + [emmy.autodiff :as ad] [emmy.dual :as d] [emmy.generators :as sg] [emmy.generic :as g] [emmy.numerical.derivative :refer [D-numeric]] [emmy.simplify :refer [hermetic-simplify-fixture]] - [emmy.tape :as tape] [emmy.value :as v] [same.core :refer [ish? 
with-comparator]])) (use-fixtures :each hermetic-simplify-fixture) -(defn- derivative - "A bare-bones derivative implementation, used for testing the functionality made - available by [[emmy.dual/Dual]]. The real version lives - at [[emmy.calculus.derivative/derivative]]" - [f] - (let [tag (d/fresh-tag)] - (fn [x] - (-> (f (d/bundle-element x 1 tag)) - (d/extract-tangent tag d/FORWARD-MODE))))) - (defn nonzero [gen] (gen/fmap (fn [x] (if (= x 0) @@ -197,7 +187,7 @@ (g/simplify not-simple))) "simplify simplifies primal and tangent") - (is (= "#emmy.tape.Dual{:tag 0, :primal (expt x 4), :tangent (* 4 (expt x 3))}" + (is (= "#emmy.dual.Dual{:tag 0, :primal (expt x 4), :tangent (* 4 (expt x 3))}" (str (g/simplify not-simple))) "str representation properly simplifies."))))))) @@ -210,11 +200,11 @@ (g/* x g/square) (g/* x g/cube))] (g x))) - Df (derivative f)] - (is (= ((derivative (g/* identity g/square)) 10) + Df (d/derivative f)] + (is (= ((d/derivative (g/* identity g/square)) 10) (Df 10)) "providing 10 takes the x*g/square branch") - (is (= ((derivative (g/* identity g/cube)) 9) + (is (= ((d/derivative (g/* identity g/cube)) 9) (Df 9)) "providing 9 takes the x*g/cube branch")))) @@ -232,7 +222,7 @@ "d/with-active-tag calls its fn with the supplied args."))) (deftest dual-arithmetic-tests - (let [Df (derivative (fn [_x]))] + (let [Df (d/derivative (fn [_x]))] (checking "derivative of nil-valued function is always zero" 100 [x sg/real] (is (zero? (Df x))))) @@ -334,7 +324,7 @@ (deftest lifted-fn-tests (letfn [(breaks? [f x] (is (thrown? #?(:clj IllegalArgumentException :cljs js/Error) - ((derivative f) x))))] + ((d/derivative f) x))))] (checking "integer-discontinuous derivatives work" 100 [x sg/real] (if (v/integral? x) (do (breaks? g/floor x) @@ -342,17 +332,17 @@ (breaks? g/integer-part x) (breaks? g/fractional-part x)) - (do (is (zero? ((derivative g/floor) x))) - (is (zero? ((derivative g/ceiling) x))) - (is (zero? ((derivative g/integer-part) x))) - (is (g/one? 
((derivative g/fractional-part) x))))))) + (do (is (zero? ((d/derivative g/floor) x))) + (is (zero? ((d/derivative g/ceiling) x))) + (is (zero? ((d/derivative g/integer-part) x))) + (is (g/one? ((d/derivative g/fractional-part) x))))))) (testing "lift-n" - (let [* (tape/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x)) - Df7 (derivative + (let [* (ad/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x)) + Df7 (d/derivative (fn x**7 [x] (* x x x x x x x))) - Df1 (derivative *) - Df0 (derivative (fn [_] (*)))] + Df1 (d/derivative *) + Df0 (d/derivative (fn [_] (*)))] (is (v/= '(* 7 (expt x 6)) (g/simplify (Df7 'x))) "functions created with lift-n can take many args (they reduce via the @@ -371,7 +361,7 @@ (g/+ (g/sin x) (g/expt x 2)) (g/+ (g/sin x) x) (g/log (g/abs x)))) - Df (derivative f) + Df (d/derivative f) Df-numeric (D-numeric f)] (with-comparator (v/within 1e-6) (checking "exercise some lifted fns" 100 @@ -385,10 +375,10 @@ "Does numeric match autodiff?")))))) (deftest sinc-etc-tests - (is (zero? ((derivative g/sinc) 0))) - (is (zero? ((derivative g/tanc) 0))) - (is (zero? ((derivative g/sinhc) 0))) - (is (zero? ((derivative g/tanhc) 0))) + (is (zero? ((d/derivative g/sinc) 0))) + (is (zero? ((d/derivative g/tanc) 0))) + (is (zero? ((d/derivative g/sinhc) 0))) + (is (zero? ((d/derivative g/tanhc) 0))) (letfn [(gen-double [min max] (gen/double* @@ -399,68 +389,68 @@ (with-comparator (v/within 1e-4) (checking "sinc" 100 [n (gen-double 1 50)] (is (ish? ((D-numeric g/sinc) n) - ((derivative g/sinc) n)))) + ((d/derivative g/sinc) n)))) ;; attempting to limit to a region where we avoid the infinities at ;; multiples of pi/2 (other than 0). (checking "tanc" 100 [n (gen-double 0.01 (- (/ Math/PI 2) 0.01))] (is (ish? ((D-numeric g/tanc) n) - ((derivative g/tanc) n)))) + ((d/derivative g/tanc) n)))) (checking "tanhc" 100 [n (gen-double 1 50)] (is (ish? 
((D-numeric g/tanhc) n) - ((derivative g/tanhc) n))))) + ((d/derivative g/tanhc) n))))) (with-comparator (v/within 1e-4) (checking "sinhc" 100 [n (gen-double 1 10)] (is (ish? ((D-numeric g/sinhc) n) - ((derivative g/sinhc) n))))) + ((d/derivative g/sinhc) n))))) (with-comparator (v/within 1e-8) (checking "acot" 100 [n (gen-double 0.01 (- (/ Math/PI 2) 0.01))] (is (ish? ((D-numeric g/acot) n) - ((derivative g/acot) n)))) + ((d/derivative g/acot) n)))) (checking "asec" 100 [n (gen-double 3 100)] (is (ish? ((D-numeric g/asec) n) - ((derivative g/asec) n)))) + ((d/derivative g/asec) n)))) (checking "acsc" 100 [n (gen-double 3 100)] (is (ish? ((D-numeric g/acsc) n) - ((derivative g/acsc) n)))) + ((d/derivative g/acsc) n)))) (checking "sech" 100 [n (gen-double 3 100)] (is (ish? ((D-numeric g/sech) n) - ((derivative g/sech) n)))) + ((d/derivative g/sech) n)))) (checking "coth" 100 [n (gen-double 1 3)] (is (ish? ((D-numeric g/coth) n) - ((derivative g/coth) n)))) + ((d/derivative g/coth) n)))) (checking "csch" 100 [n (gen-double 0.5 10)] (is (ish? ((D-numeric g/csch) n) - ((derivative g/csch) n)))) + ((d/derivative g/csch) n)))) (checking "acosh" 100 [n (gen-double 2 10)] (is (ish? ((D-numeric g/acosh) n) - ((derivative g/acosh) n)))) + ((d/derivative g/acosh) n)))) (checking "asinh" 100 [n (gen-double 2 10)] (is (ish? ((D-numeric g/asinh) n) - ((derivative g/asinh) n)))) + ((d/derivative g/asinh) n)))) (checking "atanh" 100 [n (gen-double 0.1 0.9)] (is (ish? ((D-numeric g/atanh) n) - ((derivative g/atanh) n)))) + ((d/derivative g/atanh) n)))) (checking "acoth" 100 [n (gen-double 2 10)] (is (ish? ((D-numeric g/acoth) n) - ((derivative g/acoth) n)))) + ((d/derivative g/acoth) n)))) (checking "asech" 100 [n (gen-double 0.1 0.9)] (is (ish? ((D-numeric g/asech) n) - ((derivative g/asech) n)))) + ((d/derivative g/asech) n)))) (checking "acsch" 100 [n (gen-double 2 10)] (is (ish? 
((D-numeric g/acsch) n) - ((derivative g/acsch) n))))))) + ((d/derivative g/acsch) n))))))) diff --git a/test/emmy/polynomial_test.cljc b/test/emmy/polynomial_test.cljc index 47c9984e..a2aa3870 100644 --- a/test/emmy/polynomial_test.cljc +++ b/test/emmy/polynomial_test.cljc @@ -53,11 +53,6 @@ ((D f) 'x)) "polynomial derivatives with respect to some coefficient work!"))) - (checking "perturbed?" 100 - [p (sg/polynomial :coeffs (sg/dual))] - (is (sd/perturbed? p) - "A polynomial with perturbed coefficients is perturbed.")) - (checking "polynomials are polynomial?, v/kind, misc others" 100 [p (sg/polynomial)] (is (p/polynomial? p)) diff --git a/test/emmy/tape_test.cljc b/test/emmy/tape_test.cljc index 16e26967..1bc19079 100644 --- a/test/emmy/tape_test.cljc +++ b/test/emmy/tape_test.cljc @@ -4,6 +4,7 @@ (:require #?(:clj [clojure.pprint :as pprint]) [clojure.test :refer [is deftest testing use-fixtures]] [clojure.test.check.generators :as gen] + [emmy.autodiff :as ad] [com.gfredericks.test.chuck.clojure-test :refer [checking]] [emmy.calculus.derivative :refer [D]] [emmy.dual :as d] @@ -167,7 +168,7 @@ "t/eq handles equality") (is (not (t/eq (t/make 0 n) (t/make 1 n))) - "t/eq is false for [[Differential]]s with diff tags")) + "t/eq is false for [[emmy.tape/Tape]]s with diff tags")) (checking "compare ignores tangent parts" 100 [l sg/real, r sg/real] @@ -214,10 +215,6 @@ (g/simplify tape)) "simplify simplifies all in->partial entries AND the primal "))) - (checking "d/perturbed?" 100 [tape (sg/tapecell gen/symbol)] - (is (d/perturbed? tape) - "all tags are perturbed?")) - (checking "d/extract-tangent" 100 [tag gen/nat tape (sg/tapecell gen/symbol)] (is (zero? 
(d/extract-tangent tape tag d/FORWARD-MODE)) @@ -261,21 +258,21 @@ (testing "tag-of" (checking "tag-of matches tape-tag for cells" 100 [tag gen/nat] (let [cell (t/make tag 1)] - (is (= (t/tag-of cell) + (is (= (ad/tag-of cell) (t/tape-tag cell)) "for tape cells, these should match")))) (testing "primal-of" (checking "for any other type primal-of == identity" 100 [x gen/any-equatable] - (is (= x (t/primal-of x)))) + (is (= x (ad/primal-of x)))) (checking "vs tape-primal" 100 [tape (sg/tapecell gen/symbol)] - (is (= (t/primal-of tape) + (is (= (ad/primal-of tape) (t/tape-primal tape)) "primal-of eq with and without tag") - (is (= (t/primal-of tape) - (t/primal-of tape (t/tape-tag tape))) + (is (= (ad/primal-of tape) + (ad/primal-of tape (t/tape-tag tape))) "primal-of eq with and without tag") (is (= (t/tape-primal tape) @@ -284,7 +281,7 @@ (checking "deep-primal returns nested primal" 100 [p gen/any-equatable] (let [cell (t/make 0 (t/make 1 p))] - (is (= p (t/deep-primal cell)) + (is (= p (ad/deep-primal cell)) "for tape cells, these should match")))) (deftest reverse-mode-tests @@ -323,7 +320,7 @@ (is (g/one? ((t/gradient g/fractional-part) x))))))) (testing "lift-n" - (let [* (t/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x)) + (let [* (ad/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x)) Df7 (t/gradient (fn x**7 [x] (* x x x x x x x))) Df1 (t/gradient *)