diff --git a/SciLean/Core/FunctionTransformations/RevDeriv.lean b/SciLean/Core/FunctionTransformations/RevDeriv.lean index ffd8ce94..af999fb3 100644 --- a/SciLean/Core/FunctionTransformations/RevDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevDeriv.lean @@ -1247,3 +1247,68 @@ theorem Neg.neg.arg_a0.revDerivProjUpdate_rule (-ydf.1, fun i dy dx => ydf.2 i (-dy) dx) := by unfold revDerivProjUpdate; ftrans + + +-- HMul.hmul ------------------------------------------------------------------- +-------------------------------------------------------------------------------- +open ComplexConjugate + +@[ftrans] +theorem HMul.hMul.arg_a0a1.revDeriv_rule + (f g : X → K) + (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) + : (revDeriv K fun x => f x * g x) + = + fun x => + let ydf := revDerivUpdate K f x + let zdg := revDeriv K g x + (ydf.1 * zdg.1, fun dx' => (ydf.2 (conj zdg.1 * dx') (zdg.2 (conj ydf.1 * dx')))) := +by + have ⟨_,_⟩ := hf + have ⟨_,_⟩ := hg + unfold revDerivUpdate; unfold revDeriv; simp; ftrans; ftrans; + simp [smul_push] + +@[ftrans] +theorem HMul.hMul.arg_a0a1.revDerivUpdate_rule + (f g : X → K) + (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) + : (revDerivUpdate K fun x => f x * g x) + = + fun x => + let ydf := revDerivUpdate K f x + let zdg := revDerivUpdate K g x + (ydf.1 * zdg.1, fun dx' dx => (ydf.2 (conj zdg.1 * dx') (zdg.2 (conj ydf.1 * dx') dx))) := +by + have ⟨_,_⟩ := hf + have ⟨_,_⟩ := hg + unfold revDerivUpdate; unfold revDeriv; simp; ftrans; ftrans; + simp [smul_push,add_assoc] + +@[ftrans] +theorem HMul.hMul.arg_a0a1.revDerivProj_rule + (f g : X → K) + (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) + : (revDerivProj K fun x => f x * g x) + = + fun x => + let ydf := revDerivUpdate K f x + let zdg := revDeriv K g x + (ydf.1 * zdg.1, fun _ dy => ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy))) := +by + unfold revDerivProj + ftrans; simp[StructLike.oneHot, StructLike.make] + +@[ftrans] +theorem HMul.hMul.arg_a0a1.revDerivProjUpdate_rule + (f g : X → K) + (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) + : (revDerivProjUpdate K fun x => f x * g x) + = + fun x => + let ydf := revDerivUpdate K f x + let zdg := revDerivUpdate K g x + (ydf.1 * zdg.1, fun _ dy dx => ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy) dx)) := +by + unfold revDerivProjUpdate + ftrans; simp[revDerivUpdate,add_assoc]