Skip to content

Commit

Permalink
define Function.modify and use it in StructType
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 1, 2023
1 parent 741eb98 commit ab1eb62
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 105 deletions.
2 changes: 1 addition & 1 deletion SciLean/Core/FunctionTransformations/InvFun.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def discharger (e : Expr) : SimpM (Option Expr) := do
let config : FProp.Config := {}
let state : FProp.State := { cache := cache }
let (proof?, state) ← FProp.fprop e |>.run config |>.run state
modify (fun simpState => { simpState with cache := state.cache })
_root_.modify (fun simpState => { simpState with cache := state.cache })
if proof?.isSome then
return proof?
else
Expand Down
123 changes: 49 additions & 74 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ set_option linter.unusedVariables false

namespace SciLean

-- set_option linter.ftransSsaRhs true

variable
(K I : Type _) [IsROrC K]
{X : Type _} [SemiInnerProductSpace K X]
Expand Down Expand Up @@ -148,10 +146,10 @@ theorem proj_rule (i : I)
: revDeriv K (fun (x : (i:I) → EI i) => x i)
=
fun x =>
(x i, fun dxi j => if h : i=j then h ▸ dxi else 0) :=
(x i, fun dxi => oneHot i dxi) :=
by
unfold revDeriv
funext _; ftrans; ftrans
funext _; ftrans; ftrans; simp[oneHot]
variable (I)
variable {EI}

Expand Down Expand Up @@ -199,13 +197,10 @@ theorem pi_rule
: (revDeriv K fun (x : X) (i : I) => f x i)
=
fun x =>
let xdf := revDerivProjUpdate K I f x
(fun i => xdf.1 i,
fun dy => Id.run do
let mut dx : X := 0
for i in fullRange I do
dx := xdf.2 i (dy i) dx
dx) :=
let xdf := fun i => revDerivUpdate K (f · i) x
(fun i => (xdf i).1,
fun dy =>
Function.repeatIdx (fun (i : I) dx => (xdf i).2 (dy i) dx) 0) :=
by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
Expand Down Expand Up @@ -245,12 +240,10 @@ theorem proj_rule (i : I)
: revDerivUpdate K (fun (x : (i:I) → EI i) => x i)
=
fun x =>
(x i, fun dxi dx j => if h : i=j then dx j + h ▸ dxi else dx j) :=
(x i, fun dxi dx => structModify i (fun dxi' => dxi' + dxi) dx) :=
by
unfold revDerivUpdate
simp [revDeriv.proj_rule]
funext _; ftrans; ftrans;
simp; funext dxi dx j; simp; sorry_proof
variable (I)
variable {EI}

Expand Down Expand Up @@ -293,13 +286,10 @@ theorem pi_rule
: (revDerivUpdate K fun (x : X) (i : I) => f x i)
=
fun x =>
let xdf := revDerivProjUpdate K I f x
(fun i => xdf.1 i,
fun dy dx => Id.run do
let mut dx := dx
for i in fullRange I do
dx := xdf.2 i (dy i) dx
dx) :=
let xdf := fun i => revDerivUpdate K (f · i) x
(fun i => (xdf i).1,
fun dy dx =>
Function.repeatIdx (fun (i : I) dx => (xdf i).2 (dy i) dx) dx) :=
by
unfold revDerivUpdate
simp [revDeriv.pi_rule _ _ hf, revDerivUpdate]
Expand Down Expand Up @@ -342,14 +332,21 @@ variable {Y}
theorem proj_rule [DecidableEq I] (i : ι)
: revDerivProj K I (fun (f : ι → E) => f i)
=
fun f =>
(f i, fun j dxj i' =>
if i=i' then
oneHot j dxj
else
0) :=
fun f : ι → E =>
(f i, fun j dxj => oneHot (X:=ι→E) (I:=ι×I) (i,j) dxj) :=
by
unfold revDerivProj; simp[revDeriv.proj_rule]
unfold revDerivProj; simp[revDeriv.proj_rule, oneHot]
funext x; simp; funext j de i'
if h:i=i' then
subst h
simp; congr; funext j'
if h':j=j' then
subst h'
simp
else
simp[h']
else
simp[h]

theorem comp_rule
(f : Y → E) (g : X → Y)
Expand Down Expand Up @@ -386,13 +383,10 @@ theorem pi_rule
: (revDerivProj K Unit fun (x : X) (i : ι) => f x i)
=
fun x =>
let ydf := revDerivProjUpdate K ι f x
(fun i => ydf.1 i,
fun _ df => Id.run do
let mut dx : X := 0
for i in fullRange ι do
dx := ydf.2 i (df i) dx
dx) :=
let ydf := fun i => revDerivUpdate K (f · i) x
(fun i => (ydf i).1,
fun _ df =>
Function.repeatIdx (fun i dx => (ydf i).2 (df i) dx) (0 : X)) :=
by
sorry_proof

Expand Down Expand Up @@ -447,10 +441,11 @@ by
funext j dxj f i'
apply structExt (I:=I)
intro j'
if h :i=i' then
if h :i'=i then
subst h; simp
else
simp[h]
have h' : i≠i' := by intro h''; simp[h''] at h
simp[h,h',Function.update]


theorem comp_rule
Expand Down Expand Up @@ -492,13 +487,9 @@ theorem pi_rule
: (revDerivProjUpdate K Unit fun (x : X) (i : ι) => f x i)
=
fun x =>
let ydf := revDerivProjUpdate K ι f x
(fun i => ydf.1 i,
fun _ df dx => Id.run do
let mut dx : X := dx
for i in fullRange ι do
dx := ydf.2 i (df i) dx
dx) :=
let ydf := fun i => revDerivUpdate K (f · i) x
(fun i => (ydf i).1,
fun _ df dx => Function.repeatIdx (fun i dx => (ydf i).2 (df i) dx) dx) :=
by
conv => lhs; unfold revDerivProjUpdate
simp [revDerivProj.pi_rule _ _ hf,add_assoc,add_comm]
Expand Down Expand Up @@ -1614,13 +1605,9 @@ theorem SciLean.EnumType.sum.arg_f.revDeriv_rule {ι : Type} [EnumType ι]
: revDeriv K (fun x => ∑ i, f x i)
=
fun x =>
let ydf := revDerivProjUpdate K ι f x
(∑ i, ydf.1 i,
fun dy => Id.run do
let mut dx := 0
for i in fullRange ι do
dx := ydf.2 i dy dx
dx) :=
let ydf := fun i => revDerivUpdate K (f · i) x
(∑ i, (ydf i).1,
fun dy => Function.repeatIdx (fun i dx => (ydf i).2 dy dx) 0) :=
by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
Expand All @@ -1634,13 +1621,9 @@ theorem SciLean.EnumType.sum.arg_f.revDerivUpdate_rule {ι : Type} [EnumType ι]
: revDerivUpdate K (fun x => ∑ i, f x i)
=
fun x =>
let ydf := revDerivProjUpdate K ι f x
(∑ i, ydf.1 i,
fun dy dx => Id.run do
let mut dx := dx
for i in fullRange ι do
dx := ydf.2 i dy dx
dx) :=
let ydf := fun i => revDerivUpdate K (f · i) x
(∑ i, (ydf i).1,
fun dy dx => Function.repeatIdx (fun i dx => (ydf i).2 dy dx) dx) :=
by
simp[revDerivUpdate]
ftrans
Expand All @@ -1653,15 +1636,11 @@ theorem SciLean.EnumType.sum.arg_f.revDerivProj_rule {ι : Type} [EnumType ι]
: revDerivProj K Yi (fun x => ∑ i, f x i)
=
fun x =>
let ydf := revDerivProjUpdate K (ι×Yi) f x
(∑ i, ydf.1 i,
fun i dy => Id.run do
let mut dx : X := 0
for j in fullRange ι do
dx := ydf.2 (j,i) dy dx
dx) :=
let ydf := fun i => revDerivProjUpdate K Yi (f · i) x
(∑ i, (ydf i).1,
fun j dy => Function.repeatIdx (fun (i : ι) dx => (ydf i).2 j dy dx) 0) :=
by
funext; simp[revDerivProj]; ftrans; simp; sorry_proof
funext; simp[revDerivProj]; ftrans; sorry_proof


@[ftrans]
Expand All @@ -1670,15 +1649,11 @@ theorem SciLean.EnumType.sum.arg_f.revDerivProjUpdate_rule {ι : Type} [EnumType
: revDerivProjUpdate K Yi (fun x => ∑ i, f x i)
=
fun x =>
let ydf := revDerivProjUpdate K (ι×Yi) f x
(∑ i, ydf.1 i,
fun i dy dx => Id.run do
let mut dx : X := dx
for j in fullRange ι do
dx := ydf.2 (j,i) dy dx
dx) :=
let ydf := fun i => revDerivProjUpdate K Yi (f · i) x
(∑ i, (ydf i).1,
fun j dy dx => Function.repeatIdx (fun (i : ι) dx => (ydf i).2 j dy dx) dx) :=
by
funext; simp[revDerivProjUpdate]; ftrans; simp; sorry_proof
funext; simp[revDerivProjUpdate]; ftrans; sorry_proof


-- d/ite -----------------------------------------------------------------------
Expand Down
11 changes: 5 additions & 6 deletions SciLean/Data/ArrayType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,18 @@ theorem introElem_getElem [ArrayType Cont Idx Elem] (cont : Cont)

-- TODO: Make an inplace modification
-- Maybe turn this into a class and this is a default implementation
def modifyElem [GetElem Cont Idx Elem λ _ _ => True] [SetElem Cont Idx Elem]
def _root_.SciLean.modifyElem [GetElem Cont Idx Elem λ _ _ => True] [SetElem Cont Idx Elem]
(arr : Cont) (i : Idx) (f : Elem → Elem) : Cont :=
structModify i f arr
let xi := arr[i]
setElem arr i (f xi)

set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.simp.unify true in
@[simp]
theorem getElem_modifyElem_eq [ArrayType Cont Idx Elem] (cont : Cont) (idx : Idx) (f : Elem → Elem)
: (modifyElem cont idx f)[idx] = f cont[idx] := by simp[getElem_structProj,modifyElem]; done
: (modifyElem cont idx f)[idx] = f cont[idx] := by simp[getElem_structProj,modifyElem,setElem_structModify]; done

@[simp]
theorem getElem_modifyElem_neq [inst : ArrayType Cont Idx Elem] (arr : Cont) (i j : Idx) (f : Elem → Elem)
: i ≠ j → (modifyElem arr i f)[j] = arr[j] := by intro h; simp [h,modifyElem, getElem_structProj,modifyElem]; done
: i ≠ j → (modifyElem arr i f)[j] = arr[j] := by intro h; simp [h,modifyElem, getElem_structProj,modifyElem,setElem_structModify]; done


-- Maybe turn this into a class and this is a default implementation
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Data/DataArray/DataArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ instance : IntroElem (DataArrayN α ι) ι α where
instance : StructType (DataArrayN α ι) ι (fun _ => α) where
structProj x i := x[i]
structMake f := introElem f
structModify i f x := setElem x i (f x[i])
structModify i f x := modifyElem x i f
left_inv := sorry_proof
right_inv := sorry_proof
structProj_structModify := sorry_proof
Expand Down
46 changes: 46 additions & 0 deletions SciLean/Data/Function.lean
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,49 @@ def Function.reduceD (f : ι → α) (op : α → α → α) (default : α) : α

abbrev Function.reduce [Inhabited α] (f : ι → α) (op : α → α → α) : α :=
f.reduceD op default


section FunctionModify

variable {α : Sort u} {β : α → Sort v} {α' : Sort w} [DecidableEq α] [DecidableEq α']

/-- Similar to `Function.update` but `g` specifies how to change the value at `a'`. -/
def Function.modify (f : ∀ a, β a) (a' : α) (g : β a' → β a') (a : α) : β a :=
Function.update f a' (g (f a')) a

@[simp]
theorem Function.modify_same (a : α) (g : β a → β a) (f : ∀ a, β a) : modify f a g a = g (f a) :=
dif_pos rfl

@[simp]
theorem Function.modify_noteq {a a' : α} (h : a ≠ a') (g : β a' → β a') (f : ∀ a, β a) : modify f a' g a = f a :=
dif_neg h

end FunctionModify


def Function.repeatIdx (f : ι → α → α) (init : α) : α := Id.run do
let mut x := init
for i in fullRange ι do
x := f i x
x

def Function.repeat (n : Nat) (f : α → α) (init : α) : α :=
repeatIdx (fun (_ : Fin n) x => f x) init


@[simp]
theorem Function.repeatIdx_update {α : Type _} (f : ι → α → α) (g : ι → α)
: repeatIdx (fun i g' => Function.update g' i (f i (g' i))) g
=
fun i => f i (g i) := sorry_proof

/-- Specialized formulation of `Function.repeatIdx_update` which is sometimes more
succesfull with unification -/
@[simp]
theorem Function.repeatIdx_update' {α : Type _} (f : ι → α) (g : ι → α) (op : α → α → α)
: repeatIdx (fun i g' => Function.update g' i (op (g' i) (f i))) g
=
fun i => op (g i) (f i) :=
by
apply Function.repeatIdx_update (f := fun i x => op x (f i))
28 changes: 6 additions & 22 deletions SciLean/Data/StructType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ instance (priority:=low+1) instStrucTypePiSimple
: StructType (∀ i, E i) I E where
structProj := fun f i => f i
structMake := fun f i => f i
structModify := fun i g f i' =>
if h : i'=i then
h ▸ (g (f i))
else
(f i')
structModify := fun i g f => Function.modify f i g
left_inv := by simp[LeftInverse]
right_inv := by simp[Function.RightInverse, LeftInverse]
structProj_structModify := by simp
Expand All @@ -116,11 +112,7 @@ instance (priority:=low+1) instStrucTypePi
: StructType (∀ i, E i) ((i : I) × (J i)) (fun ⟨i,j⟩ => EJ i j) where
structProj := fun f ⟨i,j⟩ => StructType.structProj (f i) j
structMake := fun f i => StructType.structMake fun j => f ⟨i,j⟩
structModify := fun ⟨i,j⟩ f x i' =>
if h : i'=i then
StructType.structModify (I:=J i') (h▸j) (h▸f) (x i')
else
(x i')
structModify := fun ⟨i,j⟩ f x => Function.modify x i (StructType.structModify (I:=J i) j f)
left_inv := by simp[LeftInverse]
right_inv := by simp[Function.RightInverse, LeftInverse]
structProj_structModify := by simp
Expand All @@ -140,17 +132,13 @@ instance instStrucTypeArrowSimple
: StructType (J → E) J (fun _ => E) where
structProj := fun f j => f j
structMake := fun f j => f j
structModify := fun j g f j' =>
if j=j' then
g (f j')
else
(f j')
structModify := fun j g f => Function.modify f j g
left_inv := by simp[LeftInverse]
right_inv := by simp[Function.RightInverse, LeftInverse]
structProj_structModify := by simp
structProj_structModify' := by
intro j j' f x H; simp
if h: j = j' then
if h: j' = j then
simp [h] at H
else
simp[h]
Expand All @@ -161,17 +149,13 @@ instance instStrucTypeArrow
: StructType (J → E) (J×I) (fun (_,i) => EI i) where
structProj := fun f (j,i) => StructType.structProj (f j) i
structMake := fun f j => StructType.structMake fun i => f (j,i)
structModify := fun (j,i) f x j' =>
if j=j' then
StructType.structModify i f (x j)
else
(x j')
structModify := fun (j,i) f x => Function.modify x j (StructType.structModify i f)
left_inv := by simp[LeftInverse]
right_inv := by simp[Function.RightInverse, LeftInverse]
structProj_structModify := by simp
structProj_structModify' := by
intro (j,i) (j',i') f x H; simp
if h: j = j' then
if h: j'=j then
subst h
if h': i=i' then
simp[h'] at H
Expand Down
5 changes: 4 additions & 1 deletion SciLean/Tactic/FTrans/Simp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import Mathlib.Algebra.SMulWithZero
namespace SciLean

-- basic algebraic operations
attribute [ftrans_simp] Prod.mk_add_mk Prod.mk_mul_mk Prod.smul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_zero zero_mul zero_smul smul_zero smul_eq_mul smul_neg eq_self iff_self mul_one one_mul one_smul Prod.fst_zero Prod.snd_zero
attribute [ftrans_simp] add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_zero zero_mul zero_smul smul_zero smul_eq_mul smul_neg eq_self iff_self mul_one one_mul one_smul

-- simp theorems for `Prod`
attribute [ftrans_simp] Prod.mk.eta Prod.fst_zero Prod.snd_zero Prod.mk_add_mk Prod.mk_mul_mk Prod.smul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk

-- simp theorems for `Equiv`
attribute [ftrans_simp] Equiv.invFun_as_coe Equiv.symm_symm
Expand Down

0 comments on commit ab1eb62

Please sign in to comment.