From 8870eaa3174e859eee3384ff701189720ae3c4df Mon Sep 17 00:00:00 2001 From: lecopivo Date: Mon, 27 Nov 2023 18:08:42 -0500 Subject: [PATCH] bug fix in revDerivProj rules for Prod.fst/snd --- .../FunctionTransformations/RevDeriv.lean | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/SciLean/Core/FunctionTransformations/RevDeriv.lean b/SciLean/Core/FunctionTransformations/RevDeriv.lean index 36afe9f4..9013b7b8 100644 --- a/SciLean/Core/FunctionTransformations/RevDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevDeriv.lean @@ -196,7 +196,7 @@ theorem pi_rule : (revDeriv K fun (x : X) (i : I) => f x i) = fun x => - let xdf := revDerivProjUpdate K I f x + let xdf := revDerivProjUpdate K ((i:I)×Unit) f x (fun i => xdf.1 i, fun dy => Id.run do let mut dx : X := 0 @@ -993,11 +993,11 @@ by @[ftrans] theorem Prod.fst.arg_self.revDerivProj_rule - (f : W → X'×Y') (hf : HasAdjDiff K f) + (f : W → X'×Y) (hf : HasAdjDiff K f) : revDerivProj K Xi (fun x => (f x).1) = fun w => - let xydf := revDerivProj K (Xi⊕Yi) f w + let xydf := revDerivProj K (Xi⊕Unit) f w (xydf.1.1, fun i dxy => xydf.2 (.inl i) dxy) := by @@ -1010,11 +1010,11 @@ by @[ftrans] theorem Prod.fst.arg_self.revDerivProjUpdate_rule - (f : W → X'×Y') (hf : HasAdjDiff K f) + (f : W → X'×Y) (hf : HasAdjDiff K f) : revDerivProjUpdate K Xi (fun x => (f x).1) = fun w => - let xydf := revDerivProjUpdate K (Xi⊕Yi) f w + let xydf := revDerivProjUpdate K (Xi⊕Unit) f w (xydf.1.1, fun i dxy dw => xydf.2 (.inl i) dxy dw) := by @@ -1051,11 +1051,11 @@ by @[ftrans] theorem Prod.snd.arg_self.revDerivProj_rule - (f : W → X'×Y') (hf : HasAdjDiff K f) + (f : W → X×Y') (hf : HasAdjDiff K f) : revDerivProj K Yi (fun x => (f x).2) = fun w => - let xydf := revDerivProj K (Xi⊕Yi) f w + let xydf := revDerivProj K (Unit⊕Yi) f w (xydf.1.2, fun i dxy => xydf.2 (.inr i) dxy) := by @@ -1068,17 +1068,17 @@ by @[ftrans] theorem Prod.snd.arg_self.revDerivProjUpdate_rule - (f : W → X'×Y') (hf : HasAdjDiff K f) + (f : W → X×Y') (hf : HasAdjDiff K f) : revDerivProjUpdate K Yi (fun x => (f x).2) = fun w => - let xydf := revDerivProjUpdate K (Xi⊕Yi) f w + let xydf := revDerivProjUpdate K (Unit⊕Yi) f w (xydf.1.2, fun i dxy dw => xydf.2 (.inr i) dxy dw) := by unfold revDerivProjUpdate funext x; ftrans; simp - + -- HAdd.hAdd ------------------------------------------------------------------- --------------------------------------------------------------------------------