diff --git a/SciLean/Data/ArrayType/Basic.lean b/SciLean/Data/ArrayType/Basic.lean index e9cd9196..7d1bc2f4 100644 --- a/SciLean/Data/ArrayType/Basic.lean +++ b/SciLean/Data/ArrayType/Basic.lean @@ -2,6 +2,7 @@ import SciLean.Util.SorryProof import SciLean.Data.Index import SciLean.Data.ListN import SciLean.Data.StructType.Basic +import SciLean.Data.Function namespace SciLean @@ -278,6 +279,46 @@ theorem sum_introElem [EnumType Idx] [ArrayType Cont Idx Elem] [AddCommMonoid El introElem fun i => ∑ j, f j i := sorry_proof + +section UsefulFunctions + + +variable + [ArrayType Cont Idx Elem] [Index Idx] + [LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx] + +def argMaxCore (cont : Cont ) : Idx × Elem := + Function.reduceD + (fun i => (i,cont[i])) + (fun (i,e) (i',e') => if e < e' then (i',e') else (i,e)) + (default, cont[default]) + +def max (cont : Cont) : Elem := + Function.reduceD + (fun i => cont[i]) + (fun e e' => if e < e' then e' else e) + (cont[default]) + +def idxMax (cont : Cont) : Idx := (argMaxCore cont).1 + + +def argMinCore (cont : Cont ) : Idx × Elem := + Function.reduceD + (fun i => (i,cont[i])) + (fun (i,e) (i',e') => if e' < e then (i',e') else (i,e)) + (default, cont[default]) + +def min (cont : Cont) : Elem := + Function.reduceD + (fun i => cont[i]) + (fun e e' => if e < e' then e' else e) + (cont[default]) + +def idxMin (cont : Cont) : Idx := (argMinCore cont).1 + +end UsefulFunctions + + end ArrayType @@ -302,3 +343,4 @@ namespace ArrayType else y[⟨i.1-n, sorry_proof⟩] end ArrayType + diff --git a/SciLean/Data/ArrayType/Properties.lean b/SciLean/Data/ArrayType/Properties.lean index bf7d2f09..58709c3e 100644 --- a/SciLean/Data/ArrayType/Properties.lean +++ b/SciLean/Data/ArrayType/Properties.lean @@ -925,6 +925,24 @@ theorem ArrayType.map.arg_farr.HasAdjDiff_rule (hf : HasAdjDiff K (fun (xe : X×Elem) => f xe.1 xe.2)) (harr : HasAdjDiff K arr) : HasAdjDiff K (fun x => map (f x) (arr x)) := sorry_proof +@[ftrans] +theorem ArrayType.map.arg_farr.revDeriv_rule + (f : X → Elem → Elem) (arr : X → Cont) + (hf : HasAdjDiff K (fun (x,e) => f x e)) (harr : HasAdjDiff K arr) + : revDeriv K (fun x => map (f x) (arr x)) + = + fun x => + let fdf := revDerivUpdate K (fun ((x,e) : X×Elem) => f x e) + let ada := revDerivUpdate K arr x + let a := ada.1 + (map (f x) a, + fun da => + let (dx,da) := Function.repeatIdx (init:=((0 : X),da)) + (fun (i : Idx) dxa => + let dxai := (fdf (x,a[i])).2 dxa.2[i] (dxa.1,0) + (dxai.1, setElem dxa.2 i dxai.2)) + ada.2 da dx) := sorry_proof + @[ftrans] theorem ArrayType.map.arg_arr.revDeriv_rule @@ -957,6 +975,41 @@ theorem ArrayType.map.arg_arr.revDerivUpdate_rule let da := mapIdx (fun i dai => let df := (fdf a[i]).2; df dai) da ada.2 da dx) := sorry_proof +-------------------------------------------------------------------------------- + +@[fprop] +theorem ArrayType.max.arg_cont.HasAdjDiff_rule + [LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx] + (arr : X → Cont) + (hf : HasAdjDiff K arr) (hfalse : fpropParam False) + : HasAdjDiff K (fun x => max (arr x)) := sorry_proof + + +@[ftrans] +theorem ArrayType.max.arg_arr.revDeriv_rule + [LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx] + (arr : X → Cont) + (hf : HasAdjDiff K arr) (hfalse : fpropParam False) + : revDeriv K (fun x => max (arr x)) + = + fun x => + let i := idxMax (arr x) + let fdf := revDerivProj K Idx arr x + (fdf.1[i], fun dei => fdf.2 i dei) := sorry_proof + + +@[ftrans] +theorem ArrayType.max.arg_arr.revDerivUpdate_rule + [LT Elem] [∀ x y : Elem, Decidable (x < y)] [Inhabited Idx] + (arr : X → Cont) + (hf : HasAdjDiff K arr) (hfalse : fpropParam False) + : revDerivUpdate K (fun x => max (arr x)) + = + fun x => + let i := idxMax (arr x) + let fdf := revDerivProjUpdate K Idx arr x + (fdf.1[i], fun dei dx => fdf.2 i dei dx) := sorry_proof + -- @[ftrans] -- theorem ArrayType.map.arg_farr.revDeriv_rule