Skip to content

Commit

Permalink
add genInterp macro pragma to generate overloads for `InterpolatorT…
Browse files Browse the repository at this point in the history
…ype`

This takes the place of the previously automagical `converter`.
  • Loading branch information
Vindaar committed Sep 13, 2024
1 parent 6d68162 commit b247975
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/numericalnim/integrate.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ import arraymancer

from ./interpolate import InterpolatorType, newHermiteSpline

# to annotate procedures with `{.genInterp.}` to generate `InterpolatorType` overloads
import private/macro_utils

## # Integration
## This module implements various integration routines.
## It provides:
##
##
## ## Integrate discrete data:
## - `trapz`, `simpson`: works for any spacing between points.
## - `romberg`: requires equally spaced points and the number of points must be of the form 2^k + 1 ie 3, 5, 9, 17, 33, 65, 129 etc.
Expand All @@ -27,7 +30,7 @@ runnableExamples:
## It also handles infinite integration limits.
## - `gaussQuad`: Fixed step size Gaussian quadrature.
## - `romberg`: Adaptive method based on Richardson Extrapolation.
## - `adaptiveSimpson`: Adaptive step size.
## - `adaptiveSimpson`: Adaptive step size.
## - `simpson`: Fixed step size.
## - `trapz`: Fixed step size.

Expand All @@ -36,7 +39,7 @@ runnableExamples:

proc f(x: float, ctx: NumContext[float, float]): float =
exp(x)

let a = 0.0
let b = Inf
let integral = adaptiveGauss(f, a, b)
Expand Down Expand Up @@ -74,10 +77,9 @@ type
IntervalList[T; U; V] = object
list: seq[IntervalType[T, U, V]] # contains all the intervals sorted from smallest to largest error


# N: #intervals
proc trapz*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 500, ctx: NumContext[T, float] = nil): T =
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using the trapezoidal rule.
##
## Input:
Expand Down Expand Up @@ -174,7 +176,7 @@ proc cumtrapz*[T](f: NumContextProc[T, float], X: openArray[float],


proc simpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 500, ctx: NumContext[T, float] = nil): T =
N = 500, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -252,7 +254,7 @@ proc simpson*[T](Y: openArray[T], X: openArray[float]): T =
result += alpha * ySorted[2*i + 2] + beta * ySorted[2*i + 1] + eta * ySorted[2*i]

proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an adaptive Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -284,7 +286,7 @@ proc adaptiveSimpson*[T](f: NumContextProc[T, float], xStart, xEnd: float,
return left + right

proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T =
tol: float, ctx: NumContext[T, float], reused_points: array[3, T]): T {.genInterp.} =
let zero = reused_points[0] - reused_points[0]
let dx1 = (xEnd - xStart) / 2
let dx2 = (xEnd - xStart) / 4
Expand All @@ -302,7 +304,7 @@ proc internal_adaptiveSimpson[T](f: NumContextProc[T, float], xStart, xEnd: floa
return left + right

proc adaptiveSimpson2*[T](f: NumContextProc[T, float], xStart, xEnd: float,
tol = 1e-8, ctx: NumContext[T, float] = nil): T =
tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an adaptive Simpson's rule.
##
## Input:
Expand Down Expand Up @@ -399,7 +401,7 @@ proc cumsimpson*[T](f: NumContextProc[T, float], X: openArray[float],
result = hermiteInterpolate(X, t, ys, dy)

proc romberg*[T](f: NumContextProc[T, float], xStart, xEnd: float,
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
depth = 8, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Romberg Integration.
##
## Input:
Expand Down Expand Up @@ -594,7 +596,7 @@ proc getGaussLegendreWeights(nPoints: int): tuple[nodes: seq[float], weights: se
return gaussWeights[nPoints]

proc gaussQuad*[T](f: NumContextProc[T, float], xStart, xEnd: float,
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T =
N = 100, nPoints = 7, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using Gaussian Quadrature.
## Has 20 different sets of weights, ranging from 1 to 20 function evaluations per subinterval.
##
Expand Down Expand Up @@ -654,7 +656,7 @@ proc calcGaussKronrod[T; U](f: NumContextProc[T, U], xStart, xEnd: U, ctx: NumCo


proc adaptiveGaussLocal*[T](f: NumContextProc[T, float],
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T =
xStart, xEnd: float, tol = 1e-8, ctx: NumContext[T, float] = nil): T {.genInterp.} =
## Calculate the integral of f using an locally adaptive Gauss-Kronrod Quadrature.
##
## Input:
Expand Down Expand Up @@ -854,7 +856,7 @@ template adaptiveGaussImpl(): untyped {.dirty.} =
totalValue += highValue

proc adaptiveGauss*[T; U](f_in: NumContextProc[T, U],
xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): T =
xStart_in, xEnd_in: U, tol = 1e-8, initialPoints: openArray[U] = @[], maxintervals: int = 10000, ctx: NumContext[T, U] = nil): T {.genInterp.} =
## Calculate the integral of f using an globally adaptive Gauss-Kronrod Quadrature. Inf and -Inf can be used as integration limits.
##
## Input:
Expand Down Expand Up @@ -909,7 +911,10 @@ proc cumGaussSpline*[T; U](f_in: NumContextProc[T, U],
result = newHermiteSpline[T](xs, ys)

proc cumGauss*[T](f_in: NumContextProc[T, float],
X: openArray[float], tol = 1e-8, initialPoints: openArray[float] = @[], maxintervals: int = 10000, ctx: NumContext[T, float] = nil): seq[T] =
X: openArray[float], tol = 1e-8,
initialPoints: openArray[float] = @[],
maxintervals: int = 10000,
ctx: NumContext[T, float] = nil): seq[T] {.genInterp.} =
## Calculate the cumulative integral of f using an globally adaptive Gauss-Kronrod Quadrature.
## Returns a sequence of values which is the cumulative integral of f at the points defined in X.
## Important: because of the much higher order of the Gauss-Kronrod quadrature (order 21) compared to the interpolating Hermite spline (order 3) you have to give it a large amount of initialPoints.
Expand Down
93 changes: 93 additions & 0 deletions src/numericalnim/private/macro_utils.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import std / macros
proc checkArgNumContext(fn: NimNode) =
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
let params = fn.params
# FormalParams <- `.params`
# Ident "T"
# IdentDefs <- `params[1]`
# Sym "f"
# BracketExpr <- `params[1][1]`
# Sym "NumContextProc" <- `params[1][1][0]`
# Ident "T"
# Sym "float"
# Empty
expectKind params, nnkFormalParams
expectKind params[1], nnkIdentDefs
expectKind params[1][1], nnkBracketExpr
expectKind params[1][1][0], {nnkSym, nnkIdent}
if params[1][1][0].strVal != "NumContextProc":
error("The function annotated with `{.genInterp.}` does not take a `NumContextProc` as the firs argument.")

proc replaceNumCtxArg(fn: NimNode): NimNode =
## Checks the first argument of the given proc is indeed a `NumContextProc` argument.
## MUST run `checkArgNumContext` on `fn` first.
##
## It returns the identifier of the first argument.
var params = fn.params # see `checkArgNNumContext`
expectKind params[1][0], {nnkSym, nnkIdent}
result = ident(params[1][0].strVal)
params[1] = nnkIdentDefs.newTree(
result,
nnkBracketExpr.newTree(
ident"InterpolatorType",
ident"T"
),
newEmptyNode()
)
fn.params = params

proc untype(n: NimNode): NimNode =
case n.kind
of nnkSym: result = ident(n.strVal)
of nnkIdent: result = n
else:
error("Cannot untype the argument: " & $n.treerepr)

proc genOriginalCall(fn: NimNode, ncp: NimNode): NimNode =
## Generates a call to the original procedure `fn` with `ncp`
## as the first argument
let fnName = fn.name
let params = fn.params
# extract all arguments we need to pass from `params`
var p = newSeq[NimNode]()
p.add ncp
for i in 2 ..< params.len: # first param is return type, second is parameter we replace
expectKind params[i], nnkIdentDefs
if params[i].len in 0 .. 2:
error("Invalid parameter: " & $params[i].treerepr)
else: # one or more arg of this type
# IdentDefs <- Example with 2 arguments of the same type
# Ident "xStart" <- index `0`
# Ident "xEnd" <- index `len - 3 = 4 - 3 = 1`
# Ident "float"
# Empty
for j in 0 .. params[i].len - 3:
p.add untype(params[i][j])
# generate the call
result = nnkCall.newTree(fnName)
for el in p:
result.add el

macro genInterp*(fn: untyped): untyped =
## Takes a `proc` with a `NumContextProc` parameter as the first argument
## and returns two procedures:
## 1. The original proc
## 2. An overload, which converts an `InterpolatorType[T]` argument to a
## `NumContextProc[T, float]` using the conversion proc.
doAssert fn.kind in {nnkProcDef, nnkFuncDef}
result = newStmtList(fn)
# 1. check arg
checkArgNumContext(fn)
# 2. generate overload
var new = fn.copyNimTree()
# 2a. replace first argument by `InterpolatorType[T]`
let arg = new.replaceNumCtxArg()
# 2b. add body with NumContextProc
let ncpIdent = ident"ncp"
new.body = quote do:
mixin eval # defined in `interpolate`, but macro used in `integrate`
let `ncpIdent` = proc(x: float, ctx: NumContext[T, float]): T = eval(`arg`, x)
# 2c. add call to original proc
new.body.add genOriginalCall(fn, ncpIdent)
# 3. finalize
result.add new

0 comments on commit b247975

Please sign in to comment.