Skip to content

Commit

Permalink
min max on ArrayType
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Dec 6, 2023
1 parent 6975f5c commit 3010c78
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
42 changes: 42 additions & 0 deletions SciLean/Data/ArrayType/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import SciLean.Util.SorryProof
import SciLean.Data.Index
import SciLean.Data.ListN
import SciLean.Data.StructType.Basic
import SciLean.Data.Function

namespace SciLean

Expand Down Expand Up @@ -278,6 +279,46 @@ theorem sum_introElem [EnumType Idx] [ArrayType Cont Idx Elem] [AddCommMonoid El
introElem fun i => ∑ j, f j i
:= sorry_proof


section UsefulFunctions


variable
[ArrayType Cont Idx Elem] [Index Idx]
[LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx]

def argMaxCore (cont : Cont ) : Idx × Elem :=
Function.reduceD
(fun i => (i,cont[i]))
(fun (i,e) (i',e') => if e < e' then (i',e') else (i,e))
(default, cont[default])

def max (cont : Cont) : Elem :=
Function.reduceD
(fun i => cont[i])
(fun e e' => if e < e' then e' else e)
(cont[default])

def idxMax (cont : Cont) : Idx := (argMaxCore cont).1


def argMinCore (cont : Cont ) : Idx × Elem :=
Function.reduceD
(fun i => (i,cont[i]))
(fun (i,e) (i',e') => if e' < e then (i',e') else (i,e))
(default, cont[default])

def min (cont : Cont) : Elem :=
Function.reduceD
(fun i => cont[i])
(fun e e' => if e < e' then e' else e)
(cont[default])

def idxMin (cont : Cont) : Idx := (argMinCore cont).1

end UsefulFunctions


end ArrayType


Expand All @@ -302,3 +343,4 @@ namespace ArrayType
else y[⟨i.1-n, sorry_proof⟩]

end ArrayType

53 changes: 53 additions & 0 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,24 @@ theorem ArrayType.map.arg_farr.HasAdjDiff_rule
(hf : HasAdjDiff K (fun (xe : X×Elem) => f xe.1 xe.2)) (harr : HasAdjDiff K arr)
: HasAdjDiff K (fun x => map (f x) (arr x)) := sorry_proof

@[ftrans]
theorem ArrayType.map.arg_farr.revDeriv_rule
(f : X → Elem → Elem) (arr : X → Cont)
(hf : HasAdjDiff K (fun (x,e) => f x e)) (harr : HasAdjDiff K arr)
: revDeriv K (fun x => map (f x) (arr x))
=
fun x =>
let fdf := revDerivUpdate K (fun ((x,e) : X×Elem) => f x e)
let ada := revDerivUpdate K arr x
let a := ada.1
(map (f x) a,
fun da =>
let (dx,da) := Function.repeatIdx (init:=((0 : X),da))
(fun (i : Idx) dxa =>
let dxai := (fdf (x,a[i])).2 dxa.2[i] (dxa.1,0)
(dxai.1, setElem dxa.2 i dxai.2))
ada.2 da dx) := sorry_proof


@[ftrans]
theorem ArrayType.map.arg_arr.revDeriv_rule
Expand Down Expand Up @@ -957,6 +975,41 @@ theorem ArrayType.map.arg_arr.revDerivUpdate_rule
let da := mapIdx (fun i dai => let df := (fdf a[i]).2; df dai) da
ada.2 da dx) := sorry_proof

--------------------------------------------------------------------------------

@[fprop]
theorem ArrayType.max.arg_cont.HasAdjDiff_rule
[LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx]
(arr : X → Cont)
(hf : HasAdjDiff K arr) (hfalse : fpropParam False)
: HasAdjDiff K (fun x => max (arr x)) := sorry_proof


@[ftrans]
theorem ArrayType.max.arg_arr.revDeriv_rule
[LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx]
(arr : X → Cont)
(hf : HasAdjDiff K arr) (hfalse : fpropParam False)
: revDeriv K (fun x => max (arr x))
=
fun x =>
let i := idxMax (arr x)
let fdf := revDerivProj K Idx arr x
(fdf.1[i], fun dei => fdf.2 i dei) := sorry_proof


@[ftrans]
theorem ArrayType.max.arg_arr.revDerivUpdate_rule
[LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx]
(arr : X → Cont)
(hf : HasAdjDiff K arr) (hfalse : fpropParam False)
: revDerivUpdate K (fun x => max (arr x))
=
fun x =>
let i := idxMax (arr x)
let fdf := revDerivProjUpdate K Idx arr x
(fdf.1[i], fun dei dx => fdf.2 i dei dx) := sorry_proof


-- @[ftrans]
-- theorem ArrayType.map.arg_farr.revDeriv_rule
Expand Down

0 comments on commit 3010c78

Please sign in to comment.