Skip to content

Commit

Permalink
clean up set_default_scalar command
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Mar 12, 2024
1 parent cc84860 commit 9ecd354
Show file tree
Hide file tree
Showing 12 changed files with 81 additions and 141 deletions.
8 changes: 4 additions & 4 deletions SciLean/Core/FunctionSpaces/ContCDiffMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ macro X:term:25 " ⟿[" K:term "," n:term "] " Y:term:26 : term =>
`(ContCDiffMap $K $n $X $Y)

macro X:term:25 " ⟿[" n:term "] " Y:term:26 : term =>
`(ContCDiffMap currentScalar% $n $X $Y)
`(ContCDiffMap defaultScalar% $n $X $Y)

macro X:term:25 " ⟿ " Y:term:26 : term =>
`(ContCDiffMap currentScalar% ∞ $X $Y)
`(ContCDiffMap defaultScalar% ∞ $X $Y)

@[app_unexpander ContCDiffMap] def unexpandContCDiffMap : Lean.PrettyPrinter.Unexpander
| `($(_) $R $n $X $Y) => `($X ⟿[$R,$n] $Y)
Expand Down Expand Up @@ -77,11 +77,11 @@ macro "fun " x:funBinder " ⟿[" K:term "," n:term "] " b:term : term =>

open Lean Parser Term
macro "fun " x:funBinder " ⟿[" n:term "] " b:term : term =>
`(ContCDiffMap.mk' currentScalar% $n (fun $x => $b) (by fun_prop (disch:=norm_num; linarith)))
`(ContCDiffMap.mk' defaultScalar% $n (fun $x => $b) (by fun_prop (disch:=norm_num; linarith)))

open Lean Parser Term
macro "fun " x:funBinder " ⟿ " b:term : term =>
`(ContCDiffMap.mk' currentScalar% ∞ (fun $x => $b) (by fun_prop (disch:=norm_num; linarith)))
`(ContCDiffMap.mk' defaultScalar% ∞ (fun $x => $b) (by fun_prop (disch:=norm_num; linarith)))

@[app_unexpander ContCDiffMap.mk'] def unexpandContCDiffMapMk : Lean.PrettyPrinter.Unexpander

Expand Down
8 changes: 4 additions & 4 deletions SciLean/Core/FunctionSpaces/ContCDiffMapFD.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ macro X:term:25 " ⟿FD[" K:term "," n:term "] " Y:term:26 : term =>
`(ContCDiffMapFD $K $n $X $Y)

macro X:term:25 " ⟿FD[" n:term "] " Y:term:26 : term =>
`(ContCDiffMapFD currentScalar% $n $X $Y)
`(ContCDiffMapFD defaultScalar% $n $X $Y)

macro X:term:25 " ⟿FD " Y:term:26 : term =>
`(ContCDiffMapFD currentScalar% ∞ $X $Y)
`(ContCDiffMapFD defaultScalar% ∞ $X $Y)

@[app_unexpander ContCDiffMapFD] def unexpandContCDiffMapFD : Lean.PrettyPrinter.Unexpander
| `($(_) $R $n $X $Y) => `($X ⟿FD[$R,$n] $Y)
Expand All @@ -63,8 +63,8 @@ macro "fun " x:funBinder " ⟿FD[" K:term "," n:term "] " b:term : term =>
`(ContCDiffMapFD.mk' $K $n (fun $x => $b) ((fwdDeriv $K fun $x => $b) rewrite_by autodiff /- check that derivative has been eliminated -/) (sorry_proof /- todo: add proof -/) sorry_proof)


macro "fun " x:funBinder " ⟿FD[" n:term "] " b:term : term => `(fun $x ⟿FD[currentScalar%, $n] $b)
macro "fun " x:funBinder " ⟿FD " b:term : term => `(fun $x ⟿FD[currentScalar%, ∞] $b)
macro "fun " x:funBinder " ⟿FD[" n:term "] " b:term : term => `(fun $x ⟿FD[defaultScalar%, $n] $b)
macro "fun " x:funBinder " ⟿FD " b:term : term => `(fun $x ⟿FD[defaultScalar%, ∞] $b)

variable {K n}

Expand Down
3 changes: 3 additions & 0 deletions SciLean/Core/FunctionSpaces/SmoothLinearMap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ instance : FunLike (SmoothLinearMap K X Y) X Y where
macro X:term:25 " ⊸[" K:term "]" Y:term:26 : term =>
`(SmoothLinearMap $K $X $Y)

macro X:term:25 " ⊸ " Y:term:26 : term =>
`(SmoothLinearMap defaultScalar% $X $Y)

@[fun_prop]
theorem SmoothLinearMap_apply_right (f : X ⊸[K] Y) : IsSmoothLinearMap K (fun x => f x) := f.2

Expand Down
18 changes: 0 additions & 18 deletions SciLean/Core/Notation/Autodiff.lean

This file was deleted.

38 changes: 16 additions & 22 deletions SciLean/Core/Notation/CDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,55 +5,52 @@ import SciLean.Tactic.Autodiff
-- Notation -------------------------------------------------------------------
--------------------------------------------------------------------------------

namespace SciLean.NotationOverField
namespace SciLean.Notation

syntax diffBinderType := " : " term
syntax diffBinderValue := ":=" term (";" term)?
syntax diffBinder := ident (diffBinderType <|> diffBinderValue)?

scoped syntax "∂ " term:66 : term
scoped syntax "∂ " diffBinder ", " term:66 : term
scoped syntax "∂ " "(" diffBinder ")" ", " term:66 : term
syntax "∂ " term:66 : term
syntax "∂ " diffBinder ", " term:66 : term
syntax "∂ " "(" diffBinder ")" ", " term:66 : term

scoped syntax "∂! " term:66 : term
scoped syntax "∂! " diffBinder ", " term:66 : term
scoped syntax "∂! " "(" diffBinder ")" ", " term:66 : term
syntax "∂! " term:66 : term
syntax "∂! " diffBinder ", " term:66 : term
syntax "∂! " "(" diffBinder ")" ", " term:66 : term

open Lean Elab Term Meta in
elab_rules : term
| `(∂ $f $x $xs*) => do
let K := mkIdent (← currentFieldName.get)
let KExpr ← elabTerm (← `($K)) none
let K ← elabTerm (← `(defaultScalar%)) none
let X ← inferType (← elabTerm x none)
let Y ← mkFreshTypeMVar
let XY ← mkArrow X Y
-- X might also be infered by the function `f`
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
let .some (X,_) := (← inferType fExpr).arrow? | return ← throwUnsupportedSyntax
if (← isDefEq KExpr X) && xs.size = 0 then
elabTerm (← `(scalarCDeriv $K $f $x $xs*)) none false
if (← isDefEq K X) && xs.size = 0 then
elabTerm (← `(scalarCDeriv defaultScalar% $f $x $xs*)) none false
else
elabTerm (← `(cderiv $K $f $x $xs*)) none false
elabTerm (← `(cderiv defaultScalar% $f $x $xs*)) none false

| `(∂ $f) => do
let K := mkIdent (← currentFieldName.get)
let K ← elabTerm (← `(defaultScalar%)) none
let X ← mkFreshTypeMVar
let Y ← mkFreshTypeMVar
let XY ← mkArrow X Y
let KExpr ← elabTerm (← `($K)) none
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
if let .some (X,_) := (← inferType fExpr).arrow? then
if (← isDefEq KExpr X) then
elabTerm (← `(scalarCDeriv $K $f)) none false
if (← isDefEq K X) then
elabTerm (← `(scalarCDeriv defaultScalar% $f)) none false
else
elabTerm (← `(cderiv $K $f)) none false
elabTerm (← `(cderiv defaultScalar% $f)) none false
else
throwUnsupportedSyntax

-- in this case we do not want to call scalarCDeriv
| `(∂ $x:ident := $val:term ; $dir:term, $b) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(cderiv $K (fun $x => $b) $val $dir)) none
elabTerm (← `(cderiv defaultScalar% (fun $x => $b) $val $dir)) none


macro_rules
Expand Down Expand Up @@ -122,6 +119,3 @@ macro_rules
| _ => `(∂ $f)

| _ => throw ()


end NotationOverField
22 changes: 9 additions & 13 deletions SciLean/Core/Notation/FwdDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,24 @@ import SciLean.Core.Notation.CDeriv
-- Notation -------------------------------------------------------------------
--------------------------------------------------------------------------------

namespace SciLean.NotationOverField
namespace SciLean.Notation


scoped syntax "∂> " term:66 : term
scoped syntax "∂> " diffBinder ", " term:66 : term
scoped syntax "∂> " "(" diffBinder ")" ", " term:66 : term
syntax "∂> " term:66 : term
syntax "∂> " diffBinder ", " term:66 : term
syntax "∂> " "(" diffBinder ")" ", " term:66 : term

scoped syntax "∂>! " term:66 : term
scoped syntax "∂>! " diffBinder ", " term:66 : term
scoped syntax "∂>! " "(" diffBinder ")" ", " term:66 : term
syntax "∂>! " term:66 : term
syntax "∂>! " diffBinder ", " term:66 : term
syntax "∂>! " "(" diffBinder ")" ", " term:66 : term

open Lean Elab Term Meta in
elab_rules : term
| `(∂> $f $x $xs*) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(fwdDeriv $K $f $x $xs*)) none
elabTerm (← `(fwdDeriv defaultScalar% $f $x $xs*)) none

| `(∂> $f) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(fwdDeriv $K $f)) none
elabTerm (← `(fwdDeriv defaultScalar% $f)) none


macro_rules
Expand Down Expand Up @@ -67,5 +65,3 @@ macro_rules
| _ => `(∂> $f)

| _ => throw ()

end NotationOverField
34 changes: 15 additions & 19 deletions SciLean/Core/Notation/Gradient.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,44 @@ import SciLean.Core.Notation.CDeriv
import SciLean.Tactic.Autodiff


namespace SciLean.NotationOverField
namespace SciLean.Notation

scoped syntax (name:=gradNotation1) "∇ " term:66 : term
scoped syntax "∇ " diffBinder ", " term:66 : term
scoped syntax "∇ " "(" diffBinder ")" ", " term:66 : term
syntax (name:=gradNotation1) "∇ " term:66 : term
syntax "∇ " diffBinder ", " term:66 : term
syntax "∇ " "(" diffBinder ")" ", " term:66 : term

scoped syntax "∇! " term:66 : term
scoped syntax "∇! " diffBinder ", " term:66 : term
scoped syntax "∇! " "(" diffBinder ")" ", " term:66 : term
syntax "∇! " term:66 : term
syntax "∇! " diffBinder ", " term:66 : term
syntax "∇! " "(" diffBinder ")" ", " term:66 : term


open Lean Elab Term Meta in
elab_rules (kind:=gradNotation1) : term
| `(∇ $f $x $xs*) => do
let K := mkIdent (← currentFieldName.get)
let KExpr ← elabTerm (← `($K)) none
let K ← elabTerm (← `(defaultScalar%)) none
let X ← inferType (← elabTerm x none)
let Y ← mkFreshTypeMVar
let XY ← mkArrow X Y
-- Y might also be infered by the function `f`
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
let .some (_,Y) := (← inferType fExpr).arrow?
| return ← throwUnsupportedSyntax
if (← isDefEq KExpr Y) then
elabTerm (← `(scalarGradient $K $f $x $xs*)) none false
if (← isDefEq K Y) then
elabTerm (← `(scalarGradient defaultScalar% $f $x $xs*)) none false
else
elabTerm (← `(gradient $K $f $x $xs*)) none false
elabTerm (← `(gradient defaultScalar% $f $x $xs*)) none false

| `(∇ $f) => do
let K := mkIdent (← currentFieldName.get)
let K ← elabTerm (← `(defaultScalar%)) none
let X ← mkFreshTypeMVar
let Y ← mkFreshTypeMVar
let XY ← mkArrow X Y
let KExpr ← elabTerm (← `($K)) none
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
if let .some (_,Y) := (← inferType fExpr).arrow? then
if (← isDefEq KExpr Y) then
elabTerm (← `(scalarGradient $K $f)) none false
if (← isDefEq K Y) then
elabTerm (← `(scalarGradient defaultScalar% $f)) none false
else
elabTerm (← `(gradient $K $f)) none false
elabTerm (← `(gradient defaultScalar% $f)) none false
else
throwUnsupportedSyntax

Expand Down Expand Up @@ -107,5 +105,3 @@ macro_rules
| _ => `(∇ $f)

| _ => throw ()

end SciLean.NotationOverField
22 changes: 9 additions & 13 deletions SciLean/Core/Notation/RevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,22 @@ import SciLean.Core.FunctionTransformations.RevDeriv
-- Notation -------------------------------------------------------------------
--------------------------------------------------------------------------------

namespace SciLean.NotationOverField
namespace SciLean.Notation

scoped syntax "<∂ " term:66 : term
scoped syntax "<∂ " diffBinder ", " term:66 : term
scoped syntax "<∂ " "(" diffBinder ")" ", " term:66 : term
syntax "<∂ " term:66 : term
syntax "<∂ " diffBinder ", " term:66 : term
syntax "<∂ " "(" diffBinder ")" ", " term:66 : term

scoped syntax "<∂! " term:66 : term
scoped syntax "<∂! " diffBinder ", " term:66 : term
scoped syntax "<∂! " "(" diffBinder ")" ", " term:66 : term
syntax "<∂! " term:66 : term
syntax "<∂! " diffBinder ", " term:66 : term
syntax "<∂! " "(" diffBinder ")" ", " term:66 : term

open Lean Elab Term Meta in
elab_rules : term
| `(<∂ $f $xs*) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(revDeriv $K $f $xs*)) none
elabTerm (← `(revDeriv defaultScalar% $f $xs*)) none
| `(<∂ $f) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(revDeriv $K $f)) none
elabTerm (← `(revDeriv defaultScalar% $f)) none
-- | `(<∂ $x:ident := $val:term ; $codir:term, $b) => do
-- let K := mkIdent (← currentFieldName.get)
-- elabTerm (← `(revDerivEval $K (fun $x => $b) $val $codir)) none
Expand Down Expand Up @@ -69,5 +67,3 @@ macro_rules
-- | _ => throw ()

-- | _ => throw ()

end NotationOverField
23 changes: 6 additions & 17 deletions SciLean/Core/NotationOverField.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,20 @@ open Lean Elab Command Term
namespace SciLean
namespace NotationOverField

syntax "defaultScalar%" : term

initialize currentFieldName : IO.Ref Name ← IO.mkRef default

elab "open_notation_over_field" K:ident : command => do
currentFieldName.set K.getId
Lean.Elab.Command.elabCommand <| ←
`(open SciLean.NotationOverField)

macro "set_default_scalar " K:ident : command => `(open_notation_over_field $K)


syntax "currentScalar%" : term

macro_rules | `(currentScalar%) => Lean.Macro.throwError "\
macro_rules | `(defaultScalar%) => Lean.Macro.throwError "\
This expression is using notation requiring to know the default scalar type. \
To set it, add this command somewhere to your file \
\n\n set_current_scalar R\
\n\n set_default_scalar R\
\n\n\
where R is the desired scalar, usually ℝ or Float.\n\n\
If you want to write code that is agnostic about the particular scalar \
type then introduce a generic type R with instance of `RealScalar R` \
\n\n variable {R : Type _} [RealScalar R]\
\n set_current_scalar R\
\n set_default_scalar R\
\n\n\
TODO: write a doc page about writing field polymorphic code and add a link here"

macro "set_current_scalar" r:term : command =>
`(local macro_rules | `(currentScalar%) => `($r))
macro "set_default_scalar" r:term : command =>
`(local macro_rules | `(defaultScalar%) => `($r))
3 changes: 1 addition & 2 deletions SciLean/Core/Objects/IsReal.lean
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ open Lean Elab Term
WARRNING: This is override for normal norm notation that provides computable version of norm if available.
-/
scoped elab (priority := high) "‖" x:term "‖" : term => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(cnorm (R:=$K) $x)) none
elabTerm (← `(cnorm (R:=defaultScalar%) $x)) none
-- TODO: fall back to normal norm if

end NotationOverField
Loading

0 comments on commit 9ecd354

Please sign in to comment.