diff --git a/SciLean/Core/Notation/FwdCDeriv.lean b/SciLean/Core/Notation/FwdCDeriv.lean index c23e7534..844293b0 100644 --- a/SciLean/Core/Notation/FwdCDeriv.lean +++ b/SciLean/Core/Notation/FwdCDeriv.lean @@ -8,19 +8,26 @@ import SciLean.Core.FunctionTransformations.FwdCDeriv namespace SciLean.NotationOverField -scoped syntax "∂> " term:66 : term +scoped syntax "∂> " term+ : term scoped syntax "∂> " diffBinder ", " term:66 : term scoped syntax "∂> " "(" diffBinder ")" ", " term:66 : term +scoped syntax "∂>! " term+ : term +scoped syntax "∂>! " diffBinder ", " term:66 : term +scoped syntax "∂>! " "(" diffBinder ")" ", " term:66 : term + open Lean Elab Term Meta in elab_rules : term +| `(∂> $f $xs*) => do + let K := mkIdent (← currentFieldName.get) + elabTerm (← `(fwdCDeriv $K $f $xs*)) none + | `(∂> $f) => do let K := mkIdent (← currentFieldName.get) elabTerm (← `(fwdCDeriv $K $f)) none macro_rules -| `(∂> $f $xs*) => `((∂> $f) $xs*) | `(∂> $x:ident, $b) => `(∂> (fun $x => $b)) | `(∂> $x:ident := $val:term, $b) => `(∂> (fun $x => $b) $val) | `(∂> $x:ident : $type:term, $b) => `(∂> fun $x : $type => $b) @@ -28,6 +35,16 @@ macro_rules | `(∂> ($b:diffBinder), $f) => `(∂> $b, $f) +macro_rules +| `(∂>! $f $xs*) => `((∂> $f $xs*) rewrite_by ftrans; ftrans; ftrans) +| `(∂>! $f) => `((∂> $f) rewrite_by ftrans; ftrans; ftrans) +| `(∂>! $x:ident, $b) => `(∂>! (fun $x => $b)) +| `(∂>! $x:ident := $val:term, $b) => `(∂>! (fun $x => $b) $val) +| `(∂>! $x:ident : $type:term, $b) => `(∂>! fun $x : $type => $b) +| `(∂>! $x:ident := $val:term ; $dir:term, $b) => `(∂>! (fun $x => $b) $val $dir) +| `(∂>! ($b:diffBinder), $f) => `(∂>! $b, $f) + + @[app_unexpander fwdCDeriv] def unexpandFwdCDeriv : Lean.PrettyPrinter.Unexpander | `($(_) $_ $f:term $x $dx) => diff --git a/SciLean/Core/Notation/Gradient.lean b/SciLean/Core/Notation/Gradient.lean index f9d1ece4..5194089f 100644 --- a/SciLean/Core/Notation/Gradient.lean +++ b/SciLean/Core/Notation/Gradient.lean @@ -5,20 +5,39 @@ import SciLean.Core.Notation.CDeriv namespace SciLean.NotationOverField -scoped syntax (name:=gradNotation1) "∇ " term:66 : term +scoped syntax (name:=gradNotation1) "∇ " term+ : term scoped syntax "∇ " diffBinder ", " term:66 : term scoped syntax "∇ " "(" diffBinder ")" ", " term:66 : term -scoped syntax "∇! " term:66 : term + +scoped syntax "∇! " term+ : term scoped syntax "∇! " diffBinder ", " term:66 : term scoped syntax "∇! " "(" diffBinder ")" ", " term:66 : term open Lean Elab Term Meta in elab_rules (kind:=gradNotation1) : term +| `(∇ $f $x $xs*) => do + let K := mkIdent (← currentFieldName.get) + let KExpr ← elabTerm (← `($K)) none + let X ← inferType (← elabTerm x none) + let Y ← mkFreshTypeMVar + let XY ← mkArrow X Y + -- Y might also be infered by the function `f` + let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false + let .some (_,Y) := (← inferType fExpr).arrow? + | return ← throwUnsupportedSyntax + if (← isDefEq KExpr Y) then + elabTerm (← `(scalarGradient $K $f $x $xs*)) none false + else + elabTerm (← `(gradient $K $f $x $xs*)) none false + | `(∇ $f) => do let K := mkIdent (← currentFieldName.get) + let X ← mkFreshTypeMVar + let Y ← mkFreshTypeMVar + let XY ← mkArrow X Y let KExpr ← elabTerm (← `($K)) none - let fExpr ← elabTerm f none + let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false if let .some (_,Y) := (← inferType fExpr).arrow? then if (← isDefEq KExpr Y) then elabTerm (← `(scalarGradient $K $f)) none false @@ -27,24 +46,19 @@ elab_rules (kind:=gradNotation1) : term else throwUnsupportedSyntax --- open Lean Elab Term Meta in --- elab_rules (kind:=gradNotation1) : term --- | `(∇ $x:ident := $val:term; $codir:term, $b) => do --- let K := mkIdent (← currentFieldName.get) --- elabTerm (← `(gradient $K (fun $x => $b) $val $codir)) none false - macro_rules | `(∇ $x:ident, $f) => `(∇ fun $x => $f) | `(∇ $x:ident : $type:term, $f) => `(∇ fun $x : $type => $f) -| `(∇ $x:ident := $val:term, $f) => `((∇ fun $x => $f) $val) +| `(∇ $x:ident := $val:term, $f) => `(∇ (fun $x => $f) $val) | `(∇ ($b:diffBinder), $f) => `(∇ $b, $f) -| `(∇! $f) => `((∇ $f) rewrite_by autodiff) + +macro_rules +| `(∇! $f) => `((∇ $f) rewrite_by ftrans; ftrans; ftrans) | `(∇! $x:ident, $f) => `(∇! fun $x => $f) | `(∇! $x:ident : $type:term, $f) => `(∇! fun $x : $type => $f) -| `(∇! $x:ident := $val:term, $f) => `((∇! fun $x => $f) $val) +| `(∇! $x:ident := $val:term, $f) => `(∇! (fun $x => $f) $val) | `(∇! ($b:diffBinder), $f) => `(∇! $b, $f) - @[app_unexpander gradient] def unexpandGradient : Lean.PrettyPrinter.Unexpander | `($(_) $_ $f:term $x $dy $z $zs*) => diff --git a/SciLean/Core/Notation/RevCDeriv.lean b/SciLean/Core/Notation/RevCDeriv.lean index f234c509..93527d4f 100644 --- a/SciLean/Core/Notation/RevCDeriv.lean +++ b/SciLean/Core/Notation/RevCDeriv.lean @@ -1,19 +1,26 @@ import SciLean.Core.Notation.CDeriv import SciLean.Core.FunctionTransformations.RevCDeriv + -------------------------------------------------------------------------------- -- Notation ------------------------------------------------------------------- -------------------------------------------------------------------------------- namespace SciLean.NotationOverField - -scoped syntax "<∂ " term:66 : term +scoped syntax "<∂ " term+ : term scoped syntax "<∂ " diffBinder ", " term:66 : term scoped syntax "<∂ " "(" diffBinder ")" ", " term:66 : term +scoped syntax "<∂! " term+ : term +scoped syntax "<∂! " diffBinder ", " term:66 : term +scoped syntax "<∂! " "(" diffBinder ")" ", " term:66 : term + open Lean Elab Term Meta in elab_rules : term +| `(<∂ $f $xs*) => do + let K := mkIdent (← currentFieldName.get) + elabTerm (← `(revCDeriv $K $f $xs*)) none | `(<∂ $f) => do let K := mkIdent (← currentFieldName.get) elabTerm (← `(revCDeriv $K $f)) none @@ -22,12 +29,19 @@ elab_rules : term elabTerm (← `(revCDerivEval $K (fun $x => $b) $val $codir)) none macro_rules -| `(<∂ $f $xs*) => `((<∂ $f) $xs*) | `(<∂ $x:ident, $b) => `(<∂ (fun $x => $b)) | `(<∂ $x:ident := $val:term, $b) => `(<∂ (fun $x => $b) $val) | `(<∂ $x:ident : $type:term, $b) => `(<∂ fun $x : $type => $b) | `(<∂ ($b:diffBinder), $f) => `(<∂ $b, $f) +macro_rules +| `(<∂! $f $xs*) => `((<∂ $f $xs*) rewrite_by ftrans; ftrans; ftrans) +| `(<∂! $f) => `((<∂ $f) rewrite_by ftrans; ftrans; ftrans) +| `(<∂! $x:ident, $b) => `(<∂! (fun $x => $b)) +| `(<∂! $x:ident := $val:term, $b) => `(<∂! (fun $x => $b) $val) +| `(<∂! $x:ident : $type:term, $b) => `(<∂! fun $x : $type => $b) +| `(<∂! ($b:diffBinder), $f) => `(<∂! $b, $f) + @[app_unexpander revCDeriv] def unexpandRevCDeriv : Lean.PrettyPrinter.Unexpander diff --git a/test/cderiv_notation.lean b/test/cderiv_notation.lean deleted file mode 100644 index adfbf2ec..00000000 --- a/test/cderiv_notation.lean +++ /dev/null @@ -1,25 +0,0 @@ -import SciLean.Core.Notation.CDeriv - -open SciLean - -variable {K} [IsROrC K] - -set_default_scalar K - -#check ∂ (fun x : K => x*x) -#check ∂ (fun x => x*x) 1 -#check ∂ (x:=(1:K)), x*x -#check ∂ (x:=1), x*x -#check ∂ (x:=0.1), x*x -#check ∂ (x:=((1:K),(2:K))), (x + x) - -#check ∂! (fun x : K => x^2) -#check ∂! (fun x : K×K => x + x) -#check ∂! (fun x => x*x) 1 -#check ∂! (x:=((1:K),(2:K))), (x + x) -#check ∂! (x:=1), x*x - - -variable {X} [Vec K X] (f : X → X) - -#check ∂ (x:=0), f x diff --git a/test/deriv_notation.lean b/test/deriv_notation.lean new file mode 100644 index 00000000..ad68e541 --- /dev/null +++ b/test/deriv_notation.lean @@ -0,0 +1,82 @@ +import SciLean.Core.Notation.CDeriv +import SciLean.Core.Notation.Gradient +import SciLean.Core.Notation.FwdCDeriv +import SciLean.Core.Notation.RevCDeriv +import SciLean.Core.FloatAsReal + + +open SciLean + +variable {K} [RealScalar K] + +set_default_scalar K + +#check ∂ (fun x : K => x*x) +#check ∂ (fun x => x*x) 1 +#check ∂ (x:=(1:K)), x*x +#check ∂ (x:=1), x*x +#check ∂ (x:=0.1), x*x +#check ∂ (x:=((1:K),(2:K))), (x + x) + +#check ∂! (fun x : K => x^2) +#check ∂! (fun x : K×K => x + x) +#check ∂! (fun x => x*x) 1 +#check ∂! (x:=((1:K),(2:K))), (x + x) +#check ∂! (x:=1), x*x + +variable {X} [Vec K X] (f : X → X) + +#check ∂ (x:=0), f x + +set_default_scalar Float + +#eval ∂! (fun x => x^2) 1 +#eval ∂! (fun x => x*x) 1 +#eval ∂! (x:=1), x*x +#eval ∂! (fun x => x + x) (1.0,2.0) (1.0,0.0) +#eval ∂! (x:=(1.0,2.0);(1.0,0.0)), (x + x) + + + +-------------------------------------------------------------------------------- + +set_default_scalar K + +#check ∇ x : (K×K), x.1 +#check ∇! x : (K×K), x.1 +#check ∇! x : (K×K), x.2 + +variable (y : K × K) + +#check ∇ (x:=y), (x + x) +#check ∇ (fun x => x + x) y +#check ∇ (fun x => x + x) ((1.0,2.0) : K×K) +#check (∇! x : (K×K), ⟪x,(1,0)⟫) + + +set_default_scalar Float + +#eval ∇! (fun x => x^2) 1 +#eval ∇! (fun x => x*x) 1 +#eval ∇! (x:=1), x*x +#eval ∇! (fun x : Float×Float => (x + x).2) (1.0,2.0) +#eval ∇! (x:=((1.0 : Float),(2.0:Float))), (x + x).1 + + + +-------------------------------------------------------------------------------- + +set_default_scalar K + +#check ∂>! x : K×K, (x.1 + x.2*x.1) +#check ∂>! x:=(1:K);2, (x + x*x) + + + +-------------------------------------------------------------------------------- + +set_default_scalar K + +#check <∂! x : K×K, (x.1 + x.2*x.1) +#check <∂! x:=(1:K), (x + x*x) +