Skip to content

Commit

Permalink
array notation fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Mar 11, 2024
1 parent bbd3a39 commit c92f377
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 5 deletions.
9 changes: 6 additions & 3 deletions SciLean/Data/ArrayType/Notation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ abbrev introElemNotation {Cont Idx Elem} [DecidableEq Idx] [ArrayType Cont Idx E
:= Indexed.ofFn (C := arrayTypeCont Idx Elem) f

open Lean.TSyntax.Compat in
macro "⊞ " x:term " => " b:term:51 : term => `(introElemNotation fun $x => $b)
macro "⊞ " x:term " : " X:term " => " b:term:51 : term => `(introElemNotation fun ($x : $X) => $b)
-- macro "⊞ " x:term " => " b:term:51 : term => `(introElemNotation fun $x => $b)
-- macro "⊞ " x:term " : " X:term " => " b:term:51 : term => `(introElemNotation fun ($x : $X) => $b)
open Term Function in
macro "⊞ " xs:funBinder* " => " b:term:51 : term => `(introElemNotation (HasUncurry.uncurry fun $xs* => $b))


@[app_unexpander introElemNotation]
def unexpandIntroElemNotation : Lean.PrettyPrinter.Unexpander
| `($(_) fun $x:term => $b) =>
| `($(_) fun $x => $b) =>
`(⊞ $x:term => $b)
| _ => throw ()

Expand Down
114 changes: 112 additions & 2 deletions SciLean/Doodle.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ variable
set_default_scalar R

-- fprop example
example : IsDifferentiable R fun x : R => x^2 := by fprop
example : CDifferentiable R fun x : R => x^2 := by simp (config:={zeta:=true})
-- ftrans example
example : ∂ x : R, x^2 = fun x => 2 * x := by conv => lhs; ftrans

Expand Down Expand Up @@ -45,8 +45,118 @@ def foo (x y : R) : R := x + y^2
prop_by unfold foo; fprop
trans_by unfold foo; ftrans

#print foo.arg_xy.IsDifferentiable_rule
#print foo.arg_xy.CDifferentiable_rule
#print foo.arg_xy.fwdCDeriv
#check foo.arg_xy.fwdCDeriv_rule

#check ∂>! x : R, foo x x




variable (n m : Nat) (x : Float^[n]) (y : Float^[m])

#check ⊞ i (j : Fin m) => (x[i] : Float)^j.1


#check ⊞ i j => x[i] * y[j]


#check introElemNotation (Function.HasUncurry.uncurry (fun ((i,j) : Fin n × Fin m) => (x[i] : Float)^j.1))

#check ↿(fun ((i,j) : Fin n × Fin m) => (x[i] : Float)^j.1)


#check ↿(fun i => (x[i] : Float))


#check LeanColls.Indexed.ofFn (C:=DataArrayN Float _) (Function.HasUncurry.uncurry fun i (j : Fin m) => (x[i] : Float)^j.1)

#check LeanColls.Indexed.ofFn (C:=DataArrayN Float _) (↿fun i (j : Fin m) => (x[i] : Float)^j.1)

open Lean Elab Term Meta

/-- Assuming `e = X₁ × ... Xₘ` this function returns `#[X₁, ..., Xₘ]`.
You can provide the expected number `n?` of elemnts then this function returns
`#[X₁, ..., (Xₙ × ... Xₘ)].
Returns none if `n? = 0` or `n? > m` i.e. `e` does not have enough terms.
-/
private partial def splitProdType (e : Expr) (n? : Option Nat := none) : Option (Array Expr) :=
if n? = .some 0 then
none
else
go e #[]
where
go (e : Expr) (xs : Array Expr) : Option (Array Expr) :=
if .some (xs.size + 1) = n? then
xs.push e
else
if e.isAppOfArity ``Prod 2 then
go (e.getArg! 1) (xs.push (e.getArg! 0))
else
if n?.isNone then
xs.push e
else
.none

private def mkProdElem (xs : Array Expr) : MetaM Expr :=
match xs.size with
| 0 => return default
| 1 => return xs[0]!
| _ =>
let n := xs.size
xs[0:n-1].foldrM (init:=xs[n-1]!) fun x p => mkAppM ``Prod.mk #[x,p]


/-- Turn an array of terms in into a tuple. -/
private def mkTuple (xs : Array (TSyntax `term)) : MacroM (TSyntax `term) :=
`(term| ($(xs[0]!), $(xs[1:]),*))


open Lean Elab LeanColls Indexed Notation Term Meta

syntax:max (name:=indexedGet) (priority:=high+1) term noWs "[" elemIndex,* "]" : term

@[term_elab indexedGet]
def elabFoo : Term.TermElab := fun stx expectedType? => do
match stx with
| `($x[$ids:elemIndex,*]) => do

IO.println "asdfads"

let ids := ids.getElems

let getElemFallback : TermElabM (Option Expr) := do
if ids.size ≠ 1 then
return none
match ids[0]! with
| `(elemIndex| $i:term) => elabTerm (← `(getElem $x $i (by get_elem_tactic))) none
| `(elemIndex| $i : $j) => elabTerm (← `(let a := $x; Array.toSubarray a $i $j)) none
| `(elemIndex| $i :) => elabTerm (← `(let a := $x; Array.toSubarray a $i a.size)) none
| `(elemIndex| : $j) => elabTerm (← `(let a := $x; Array.toSubarray a 0 $j)) none
| _ => return none


let x ← elabTerm x none
let X ← inferType x
let I ← mkFreshTypeMVar
let E ← mkFreshTypeMVar
let indexed := (mkAppN (← mkConstWithFreshMVarLevels ``Indexed) #[X, I, E])
let .some inst ← synthInstance? indexed
| if let .some xi ← getElemFallback then
return xi
else
throwError s!"`{← ppExpr x} : {← ppExpr X}` is not indexed type.
Please provide instance `Indexed {← ppExpr X} ?I ?E`."


| _ => throwError "asdf"


#checkfun i j => x[i] * y[j]

open Lean Elab Term in
elab (priority:=high+1) x:term noWs "[" i:term "]" : term => do
elabTerm (← `(GetElem.getElem $x $i True.intro)) none

0 comments on commit c92f377

Please sign in to comment.