diff --git a/src/numericalnim/integrate.nim b/src/numericalnim/integrate.nim index c077347..6f68416 100644 --- a/src/numericalnim/integrate.nim +++ b/src/numericalnim/integrate.nim @@ -6,10 +6,13 @@ import arraymancer from ./interpolate import InterpolatorType, newHermiteSpline +# to annotate procedures with `{.genInterp.}` to generate `InterpolatorType` overloads +import private/macro_utils + ## # Integration ## This module implements various integration routines. ## It provides: -## +## ## ## Integrate discrete data: ## - `trapz`, `simpson`: works for any spacing between points. ## - `romberg`: requires equally spaced points and the number of points must be of the form 2^k + 1 ie 3, 5, 9, 17, 33, 65, 129 etc. @@ -27,7 +30,7 @@ runnableExamples: ## It also handles infinite integration limits. ## - `gaussQuad`: Fixed step size Gaussian quadrature. ## - `romberg`: Adaptive method based on Richardson Extrapolation. -## - `adaptiveSimpson`: Adaptive step size. +## - `adaptiveSimpson`: Adaptive step size. ## - `simpson`: Fixed step size. ## - `trapz`: Fixed step size. @@ -36,7 +39,7 @@ runnableExamples: proc f(x: float, ctx: NumContext[float, float]): float = exp(x) - + let a = 0.0 let b = Inf let integral = adaptiveGauss(f, a, b) @@ -74,10 +77,9 @@ type IntervalList[T; U; V] = object list: seq[IntervalType[T, U, V]] # contains all the intervals sorted from smallest to largest error - # N: #intervals proc trapz*[T](f: NumContextProc[T, float], xStart, xEnd: float, - N = 500, ctx: NumContext[T, float] = nil): T = + N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using the trapezoidal rule. ## ## Input: @@ -174,7 +176,7 @@ proc cumtrapz*[T](f: NumContextProc[T, float], X: openArray[float], proc simpson*[T](f: NumContextProc[T, float], xStart, xEnd: float, - N = 500, ctx: NumContext[T, float] = nil): T = + N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using Simpson's rule. ## ## Input: @@ -252,7 +254,7 @@ proc simpson*[T](Y: openArray[T], X: openArray[float]): T = result += alpha * ySorted[2*i + 2] + beta * ySorted[2*i + 1] + eta * ySorted[2*i] proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float, - tol = 1e-8, ctx: NumContext[T, float] = nil): T = + tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using an adaptive Simpson's rule. ## ## Input: @@ -284,7 +286,7 @@ proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float, return left + right proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: float, - tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T = + tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T {.genInterp.} = let zero = reused_points[0] - reused_points[0] let dx1 = (xEnd - xStart) / 2 let dx2 = (xEnd - xStart) / 4 @@ -302,7 +304,7 @@ proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: floa return left + right proc adaptiveSimpson2*[T](f: NumContextProc[T, float], xStart, xEnd: float, - tol = 1e-8, ctx: NumContext[T, float] = nil): T = + tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using an adaptive Simpson's rule. ## ## Input: @@ -399,7 +401,7 @@ proc cumsimpson*[T](f: NumContextProc[T, float], X: openArray[float], result = hermiteInterpolate(X, t, ys, dy) proc romberg*[T](f: NumContextProc[T, float], xStart, xEnd: float, - depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T = + depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using Romberg Integration. ## ## Input: @@ -594,7 +596,7 @@ proc getGaussLegendreWeights(nPoints: int): tuple[nodes: seq[float], weights: se return gaussWeights[nPoints] proc gaussQuad*[T](f: NumContextProc[T, float], xStart, xEnd: float, - N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T = + N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using Gaussian Quadrature. ## Has 20 different sets of weights, ranging from 1 to 20 function evaluations per subinterval. ## @@ -654,7 +656,7 @@ proc calcGaussKronrod[T; U](f: NumContextProc[T, U], xStart, xEnd: U, ctx: NumCo proc adaptiveGaussLocal*[T](f: NumContextProc[T, float], - xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T = + xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} = ## Calculate the integral of f using an locally adaptive Gauss-Kronrod Quadrature. ## ## Input: @@ -854,7 +856,7 @@ template adaptiveGaussImpl(): untyped {.dirty.} = totalValue += highValue proc adaptiveGauss*[T; U](f_in: NumContextProc[T, U], - xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): T = + xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): T {.genInterp.} = ## Calculate the integral of f using an globally adaptive Gauss-Kronrod Quadrature. Inf and -Inf can be used as integration limits. ## ## Input: @@ -909,7 +911,10 @@ proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U], result = newHermiteSpline[T](xs, ys) proc cumGauss*[T](f_in: NumContextProc[T, float], - X: openArray[float], tol = 1e-8, initialPoints: openArray[float] = @[], maxintervals: int = 10000, ctx: NumContext[T, float] = nil): seq[T] = + X: openArray[float], tol = 1e-8, + initialPoints: openArray[float] = @[], + maxintervals: int = 10000, + ctx: NumContext[T, float] = nil): seq[T] {.genInterp.} = ## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature. ## Returns a sequence of values which is the cumulative integral of f at the points defined in X. ## Important: because of the much higher order of the Gauss-Kronrod quadrature (order 21) compared to the interpolating Hermite spline (order 3) you have to give it a large amount of initialPoints. diff --git a/src/numericalnim/private/macro_utils.nim b/src/numericalnim/private/macro_utils.nim new file mode 100644 index 0000000..f90039e --- /dev/null +++ b/src/numericalnim/private/macro_utils.nim @@ -0,0 +1,93 @@ +import std / macros +proc checkArgNumContext(fn: NimNode) = + ## Checks the first argument of the given proc is indeed a `NumContextProc` argument. + let params = fn.params + # FormalParams <- `.params` + # Ident "T" + # IdentDefs <- `params[1]` + # Sym "f" + # BracketExpr <- `params[1][1]` + # Sym "NumContextProc" <- `params[1][1][0]` + # Ident "T" + # Sym "float" + # Empty + expectKind params, nnkFormalParams + expectKind params[1], nnkIdentDefs + expectKind params[1][1], nnkBracketExpr + expectKind params[1][1][0], {nnkSym, nnkIdent} + if params[1][1][0].strVal != "NumContextProc": + error("The function annotated with `{.genInterp.}` does not take a `NumContextProc` as the firs argument.") + +proc replaceNumCtxArg(fn: NimNode): NimNode = + ## Checks the first argument of the given proc is indeed a `NumContextProc` argument. + ## MUST run `checkArgNumContext` on `fn` first. + ## + ## It returns the identifier of the first argument. + var params = fn.params # see `checkArgNNumContext` + expectKind params[1][0], {nnkSym, nnkIdent} + result = ident(params[1][0].strVal) + params[1] = nnkIdentDefs.newTree( + result, + nnkBracketExpr.newTree( + ident"InterpolatorType", + ident"T" + ), + newEmptyNode() + ) + fn.params = params + +proc untype(n: NimNode): NimNode = + case n.kind + of nnkSym: result = ident(n.strVal) + of nnkIdent: result = n + else: + error("Cannot untype the argument: " & $n.treerepr) + +proc genOriginalCall(fn: NimNode, ncp: NimNode): NimNode = + ## Generates a call to the original procedure `fn` with `ncp` + ## as the first argument + let fnName = fn.name + let params = fn.params + # extract all arguments we need to pass from `params` + var p = newSeq[NimNode]() + p.add ncp + for i in 2 ..< params.len: # first param is return type, second is parameter we replace + expectKind params[i], nnkIdentDefs + if params[i].len in 0 .. 2: + error("Invalid parameter: " & $params[i].treerepr) + else: # one or more arg of this type + # IdentDefs <- Example with 2 arguments of the same type + # Ident "xStart" <- index `0` + # Ident "xEnd" <- index `len - 3 = 4 - 3 = 1` + # Ident "float" + # Empty + for j in 0 .. params[i].len - 3: + p.add untype(params[i][j]) + # generate the call + result = nnkCall.newTree(fnName) + for el in p: + result.add el + +macro genInterp*(fn: untyped): untyped = + ## Takes a `proc` with a `NumContextProc` parameter as the first argument + ## and returns two procedures: + ## 1. The original proc + ## 2. An overload, which converts an `InterpolatorType[T]` argument to a + ## `NumContextProc[T, float]` using the conversion proc. + doAssert fn.kind in {nnkProcDef, nnkFuncDef} + result = newStmtList(fn) + # 1. check arg + checkArgNumContext(fn) + # 2. generate overload + var new = fn.copyNimTree() + # 2a. replace first argument by `InterpolatorType[T]` + let arg = new.replaceNumCtxArg() + # 2b. add body with NumContextProc + let ncpIdent = ident"ncp" + new.body = quote do: + mixin eval # defined in `interpolate`, but macro used in `integrate` + let `ncpIdent` = proc(x: float, ctx: NumContext[T, float]): T = eval(`arg`, x) + # 2c. add call to original proc + new.body.add genOriginalCall(fn, ncpIdent) + # 3. finalize + result.add new