Skip to content

Commit

Permalink
Merge pull request RemyDegenne#121 from LorenzoLuccioli/LL/DPI
Browse files Browse the repository at this point in the history
Add the DPI for the f-divergence in the general case
  • Loading branch information
RemyDegenne authored Aug 6, 2024
2 parents b501843 + c8050ed commit b209c28
Show file tree
Hide file tree
Showing 5 changed files with 423 additions and 16 deletions.
66 changes: 63 additions & 3 deletions TestingLowerBounds/CurvatureMeasure.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,75 @@ import Mathlib.Analysis.SpecialFunctions.Gamma.BohrMollerup

open MeasureTheory Set StieltjesFunction ProbabilityTheory

namespace ConvexOn

variable {𝒳 : Type*} {m𝒳 : MeasurableSpace 𝒳} {μ ν : Measure 𝒳} {f g : ℝ → ℝ} {β γ x t : ℝ}

namespace StieltjesFunction

open Set Filter Function ENNReal NNReal Topology MeasureTheory
open ENNReal (ofReal)

variable (f : StieltjesFunction)

--PR this to mathlib, just before `StieltjesFunction.measure_const`
@[simp]
lemma _root_.StieltjesFunction.measure_zero : StieltjesFunction.measure 0 = 0 :=
lemma measure_zero : StieltjesFunction.measure 0 = 0 :=
Measure.ext_of_Ioc _ _ (fun _ _ _ ↦ by simp; rfl)


--PR this to mathlib, just after `StieltjesFunction.measure_Iic`
lemma measure_Iio {l : ℝ} (hf : Tendsto f atBot (𝓝 l)) (x : ℝ) :
f.measure (Iio x) = ofReal (leftLim f x - l) := by
rw [← Iic_diff_right, measure_diff _ (measurableSet_singleton x), measure_singleton,
f.measure_Iic hf, ← ofReal_sub _ (sub_nonneg.mpr <| Monotone.leftLim_le f.mono' (le_refl _))]
<;> simp

--PR this to mathlib, just after `StieltjesFunction.measure_Ici`
lemma measure_Ioi {l : ℝ} (hf : Tendsto f atTop (𝓝 l)) (x : ℝ) :
f.measure (Ioi x) = ofReal (l - f x) := by
rw [← Ici_diff_left, measure_diff _ (measurableSet_singleton x), measure_singleton,
f.measure_Ici hf, ← ofReal_sub _ (sub_nonneg.mpr <| Monotone.leftLim_le f.mono' (le_refl _))]
<;> simp

--PR this and the following lemmas to mathlib, just after `StieltjesFunction.measure_univ`
lemma measure_Ioi_of_tendsto_atTop_atTop (hf : Tendsto f atTop atTop) (x : ℝ) :
f.measure (Ioi x) = ∞ := by
refine ENNReal.eq_top_of_forall_nnreal_le fun r ↦ ?_
obtain ⟨N, hN⟩ := eventually_atTop.mp (tendsto_atTop.mp hf (r + f x))
exact (f.measure_Ioc x (max x N) ▸ ENNReal.coe_nnreal_eq r ▸ (ENNReal.ofReal_le_ofReal <|
le_tsub_of_add_le_right <| hN _ (le_max_right x N))).trans (measure_mono Ioc_subset_Ioi_self)

lemma measure_Ici_of_tendsto_atTop_atTop (hf : Tendsto f atTop atTop) (x : ℝ) :
f.measure (Ici x) = ∞ := by
rw [← top_le_iff, ← f.measure_Ioi_of_tendsto_atTop_atTop hf x]
exact measure_mono Ioi_subset_Ici_self

lemma measure_Iic_of_tendsto_atBot_atBot (hf : Tendsto f atBot atBot) (x : ℝ) :
f.measure (Iic x) = ∞ := by
refine ENNReal.eq_top_of_forall_nnreal_le fun r ↦ ?_
obtain ⟨N, hN⟩ := eventually_atBot.mp (tendsto_atBot.mp hf (f x - r))
exact (f.measure_Ioc (min x N) x ▸ ENNReal.coe_nnreal_eq r ▸ (ENNReal.ofReal_le_ofReal <|
le_sub_comm.mp <| hN _ (min_le_right x N))).trans (measure_mono Ioc_subset_Iic_self)

lemma measure_Iio_of_tendsto_atBot_atBot (hf : Tendsto f atBot atBot) (x : ℝ) :
f.measure (Iio x) = ∞ := by
rw [← top_le_iff, ← f.measure_Iic_of_tendsto_atBot_atBot hf (x - 1)]
exact measure_mono <| Set.Iic_subset_Iio.mpr <| sub_one_lt x

lemma measure_univ_of_tendsto_atTop_atTop (hf : Tendsto f atTop atTop) :
f.measure univ = ∞ := by
rw [← top_le_iff, ← f.measure_Ioi_of_tendsto_atTop_atTop hf 0]
exact measure_mono fun _ _ ↦ trivial

lemma measure_univ_of_tendsto_atBot_atBot (hf : Tendsto f atBot atBot) :
f.measure univ = ∞ := by
rw [← top_le_iff, ← f.measure_Iio_of_tendsto_atBot_atBot hf 0]
exact measure_mono fun _ _ ↦ trivial


end StieltjesFunction

namespace ConvexOn

open Classical in
/-- The curvature measure induced by a convex function. It is defined as the only measure that has
the right derivative of the function as a CDF.
Expand Down
17 changes: 17 additions & 0 deletions TestingLowerBounds/DerivAtTop.lean
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ lemma ConvexOn.derivAtTop_eq_iff {y : EReal} (hf : ConvexOn ℝ (Ici 0) f) :
derivAtTop f = y ↔ Tendsto (fun x ↦ (rightDeriv f x : EReal)) atTop (𝓝 y) :=
hf.rightDeriv_mono'.derivAtTop_eq_iff

lemma MonotoneOn.derivAtTop_eq_top_iff (hf : MonotoneOn (rightDeriv f) (Ioi 0)) :
derivAtTop f = ⊤ ↔ Tendsto (rightDeriv f) atTop atTop := by
refine ⟨fun h ↦ ?_, fun h ↦ derivAtTop_of_tendsto_atTop h⟩
exact EReal.tendsto_toReal_atTop.comp (tendsto_nhdsWithin_of_tendsto_nhds_of_eventually_within _
(h ▸ hf.tendsto_derivAtTop) (eventually_of_forall fun _ ↦ EReal.coe_ne_top _))

lemma ConvexOn.derivAtTop_eq_top_iff (hf : ConvexOn ℝ (Ici 0) f) :
derivAtTop f = ⊤ ↔ Tendsto (rightDeriv f) atTop atTop :=
hf.rightDeriv_mono'.derivAtTop_eq_top_iff

lemma MonotoneOn.derivAtTop_ne_bot (hf : MonotoneOn (rightDeriv f) (Ioi 0)) : derivAtTop f ≠ ⊥ := by
intro h_eq
rw [hf.derivAtTop_eq_iff, ← tendsto_extendBotLtOne_rightDeriv_iff] at h_eq
Expand Down Expand Up @@ -248,6 +258,13 @@ lemma rightDeriv_le_toReal_derivAtTop (h_cvx : ConvexOn ℝ (Ici 0) f) (h : deri
· exact mem_Ioi.mpr (hx.trans_le ((le_max_right _ _).trans hy))
· exact (le_max_right _ _).trans hy

lemma rightDeriv_le_derivAtTop (h_cvx : ConvexOn ℝ (Ici 0) f) (hx : 0 < x) :
rightDeriv f x ≤ derivAtTop f := by
by_cases h : derivAtTop f = ⊤
· exact h ▸ le_top
· rw [← EReal.coe_toReal h h_cvx.derivAtTop_ne_bot, EReal.coe_le_coe_iff]
exact rightDeriv_le_toReal_derivAtTop h_cvx h hx

lemma slope_le_derivAtTop (h_cvx : ConvexOn ℝ (Ici 0) f)
(h : derivAtTop f ≠ ⊤) {x y : ℝ} (hx : 0 ≤ x) (hxy : x < y) :
(f y - f x) / (y - x) ≤ (derivAtTop f).toReal :=
Expand Down
203 changes: 191 additions & 12 deletions TestingLowerBounds/Divergences/StatInfo.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/-
Copyright (c) 2024 Rémy Degenne. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Rémy Degenne
Authors: Rémy Degenne, Lorenzo Luccioli
-/
import TestingLowerBounds.CurvatureMeasure
import TestingLowerBounds.StatInfoFun
Expand All @@ -10,6 +10,7 @@ import TestingLowerBounds.FDiv.Basic
import TestingLowerBounds.Testing.Binary
import Mathlib.MeasureTheory.Constructions.Prod.Integral
import TestingLowerBounds.ForMathlib.SetIntegral
import TestingLowerBounds.ForMathlib.Indicator

/-!
# Statistical information
Expand Down Expand Up @@ -869,9 +870,8 @@ lemma fDiv_eq_integral_fDiv_statInfoFun_of_absolutelyContinuous
((Measure.integrable_toReal_rnDeriv.sub (integrable_const 1)).const_mul _)
all_goals exact ENNReal.toReal_toEReal_of_ne_top (measure_ne_top _ _)

lemma fDiv_eq_lintegral_fDiv_statInfoFun_of_absolutelyContinuous
[IsFiniteMeasure μ] [IsFiniteMeasure ν] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f)
(h_ac : μ ≪ ν) :
lemma fDiv_eq_lintegral_fDiv_statInfoFun_of_absolutelyContinuous [IsFiniteMeasure μ]
[IsFiniteMeasure ν] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) (h_ac : μ ≪ ν) :
fDiv f μ ν = ∫⁻ x, (fDiv (statInfoFun 1 x) μ ν).toENNReal ∂(curvatureMeasure f)
+ f 1 * ν univ + rightDeriv f 1 * (μ univ - ν univ) := by
by_cases h_int : Integrable (fun x ↦ f ((∂μ/∂ν) x).toReal) ν
Expand Down Expand Up @@ -907,6 +907,151 @@ lemma fDiv_eq_lintegral_fDiv_statInfoFun_of_absolutelyContinuous
(by fun_prop) |>.ereal_toENNReal.aemeasurable
· exact eventually_of_forall fun _ ↦ EReal.toENNReal_ne_top_iff.mpr fDiv_statInfoFun_ne_top

lemma lintegral_statInfoFun_one_zero (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
∫⁻ x, ENNReal.ofReal (statInfoFun 1 x 0) ∂curvatureMeasure f
= (f 0).toEReal - f 1 + rightDeriv f 1 := by
norm_cast
have := convex_taylor hf_cvx hf_cont (a := 1) (b := 0)
simp only [zero_sub, mul_neg, mul_one, sub_neg_eq_add] at this
rw [this, intervalIntegral.integral_of_ge (zero_le_one' _), integral_neg, neg_neg,
← ofReal_integral_eq_lintegral_ofReal _
(eventually_of_forall fun x ↦ statInfoFun_nonneg 1 x 0)]
rotate_left
· refine Integrable.mono' (g := (Ioc 0 1).indicator 1) ?_
measurable_statInfoFun2.aestronglyMeasurable ?_
· exact IntegrableOn.integrable_indicator
(integrableOn_const.mpr (Or.inr measure_Ioc_lt_top)) measurableSet_Ioc
· simp_rw [Real.norm_of_nonneg (statInfoFun_nonneg 1 _ 0),
statInfoFun_of_one_of_right_le_one zero_le_one, sub_zero]
exact eventually_of_forall fun x ↦ Set.indicator_le_indicator' fun hx ↦ hx.2
rw [EReal.coe_ennreal_ofReal, max_eq_left (integral_nonneg_of_ae <| eventually_of_forall
fun x ↦ statInfoFun_nonneg 1 x 0), ← integral_indicator measurableSet_Ioc]
simp_rw [statInfoFun_of_one_of_right_le_one zero_le_one, sub_zero]

lemma fDiv_eq_lintegral_fDiv_statInfoFun_of_mutuallySingular [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) (h_ms : μ ⟂ₘ ν) :
fDiv f μ ν = ∫⁻ x, (fDiv (statInfoFun 1 x) μ ν).toENNReal ∂(curvatureMeasure f)
+ f 1 * ν univ + rightDeriv f 1 * (μ univ - ν univ) := by
have hf_cvx' : ConvexOn ℝ (Ici 0) f := (hf_cvx.subset (fun _ _ ↦ trivial) (convex_Ici 0))
have h1 : ∫⁻ x, (statInfoFun 1 x 0 * (ν univ).toEReal
+ derivAtTop (statInfoFun 1 x) * μ univ).toENNReal ∂curvatureMeasure f
= (∫⁻ x, ENNReal.ofReal (statInfoFun 1 x 0) ∂curvatureMeasure f) * ν univ
+ (∫⁻ x, (derivAtTop (statInfoFun 1 x)).toENNReal ∂curvatureMeasure f) * μ univ := by
rw [← lintegral_mul_const _ (Measurable.ennreal_ofReal measurable_statInfoFun2),
← lintegral_mul_const _]
swap
· simp_rw [derivAtTop_statInfoFun_eq]
refine (Measurable.ite (MeasurableSet.const _) ?_ ?_).coe_real_ereal.ereal_toENNReal <;>
· refine Measurable.ite (measurableSet_le (fun _ a ↦ a) ?_) ?_ ?_ <;> exact measurable_const
rw [← lintegral_add_left]
swap; · exact measurable_statInfoFun2.ennreal_ofReal.mul_const _
congr with x
rw [EReal.toENNReal_add]
rotate_left
· exact mul_nonneg (EReal.coe_nonneg.mpr (statInfoFun_nonneg 1 x 0))
(EReal.coe_ennreal_nonneg _)
· exact mul_nonneg (derivAtTop_statInfoFun_nonneg 1 x) (EReal.coe_ennreal_nonneg _)
rw [EReal.toENNReal_mul (EReal.coe_nonneg.mpr <| statInfoFun_nonneg 1 x 0),
EReal.toENNReal_mul (derivAtTop_statInfoFun_nonneg 1 x)]
simp [-statInfoFun_of_one]
have h2 : ∫⁻ x, (derivAtTop (statInfoFun 1 x)).toENNReal ∂curvatureMeasure f
= (derivAtTop f - rightDeriv f 1).toENNReal := by
calc
_ = curvatureMeasure f (Ioi 1) := by
simp_rw [derivAtTop_statInfoFun_eq, ← lintegral_indicator_one measurableSet_Ioi]
congr with x
by_cases h : x ∈ Ioi 1
· simpa [h]
· simp [h, show x ≤ 1 from le_of_not_lt h]
_ = (derivAtTop f - rightDeriv f 1).toENNReal := by
rw [curvatureMeasure_of_convexOn hf_cvx]
by_cases h_top : derivAtTop f = ⊤
· rw [h_top, EReal.top_sub_coe, EReal.toENNReal_top,
StieltjesFunction.measure_Ioi_of_tendsto_atTop_atTop]
exact hf_cvx'.derivAtTop_eq_top_iff.mp h_top
· lift (derivAtTop f) to ℝ using ⟨h_top, hf_cvx'.derivAtTop_ne_bot⟩ with x hx
rw [StieltjesFunction.measure_Ioi _ ?_ 1 (l := x)]
· norm_cast
exact (hx ▸ hf_cvx'.tendsto_toReal_derivAtTop (hx ▸ h_top) :)
simp_rw [fDiv_of_mutuallySingular h_ms, h1]
push_cast
rw [lintegral_statInfoFun_one_zero hf_cvx hf_cont, h2, EReal.coe_toENNReal]
swap
· rw [EReal.sub_nonneg (EReal.coe_ne_top _) (EReal.coe_ne_bot _)]
exact rightDeriv_le_derivAtTop hf_cvx' zero_lt_one
simp_rw [sub_eq_add_neg, ← ENNReal.toReal_toEReal_of_ne_top (measure_ne_top ν _),
← ENNReal.toReal_toEReal_of_ne_top (measure_ne_top μ _),
EReal.add_mul_coe_of_nonneg ENNReal.toReal_nonneg, ← EReal.coe_neg (ν univ).toReal,
← EReal.coe_add, ← EReal.coe_mul _ (_ + _), mul_add, EReal.coe_add, neg_mul, ← EReal.coe_mul,
mul_neg, EReal.coe_neg, add_assoc]
congr
simp_rw [add_comm (rightDeriv f 1 * (ν _).toReal).toEReal, add_assoc,
add_comm _ (rightDeriv f 1 * _).toEReal, ← add_assoc, ← sub_eq_add_neg,
EReal.add_sub_cancel_right, sub_eq_add_neg, add_assoc, add_comm _ (_ + (_ + (_ + _))),
add_comm (f 1 * _).toEReal, ← add_assoc, ← sub_eq_add_neg, EReal.add_sub_cancel_right,
sub_eq_add_neg, add_assoc, add_comm (-(rightDeriv f 1 * _).toEReal), ← add_assoc,
← sub_eq_add_neg, EReal.add_sub_cancel_right]

lemma fDiv_eq_lintegral_fDiv_statInfoFun [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f μ ν = ∫⁻ x, (fDiv (statInfoFun 1 x) μ ν).toENNReal ∂(curvatureMeasure f)
+ f 1 * ν univ + rightDeriv f 1 * (μ univ - ν univ) := by
rw [fDiv_eq_add_withDensity_singularPart _ _ (hf_cvx.subset (fun _ _ ↦ trivial) (convex_Ici 0)),
fDiv_eq_lintegral_fDiv_statInfoFun_of_mutuallySingular hf_cvx hf_cont
(μ.mutuallySingular_singularPart ν), fDiv_eq_lintegral_fDiv_statInfoFun_of_absolutelyContinuous
hf_cvx hf_cont (withDensity_absolutelyContinuous ν (∂μ/∂ν))]
have h1 : ∫⁻ x, (fDiv (statInfoFun 1 x) μ ν).toENNReal ∂curvatureMeasure f
= ∫⁻ x, (fDiv (statInfoFun 1 x) (ν.withDensity (∂μ/∂ν)) ν).toENNReal ∂curvatureMeasure f
+ ∫⁻ x, (fDiv (statInfoFun 1 x) (μ.singularPart ν) ν).toENNReal ∂curvatureMeasure f
- (∫⁻ x, .ofReal (statInfoFun 1 x 0) ∂curvatureMeasure f : EReal) * (ν univ).toReal := by
have h_nonneg (x : ℝ) : 0 ≤ fDiv (statInfoFun 1 x) μ ν := fDiv_statInfoFun_nonneg
simp_rw [fDiv_eq_add_withDensity_singularPart μ ν ((convexOn_statInfoFun 1 _).subset
(fun _ _ ↦ trivial) (convex_Ici 0))] at h_nonneg ⊢
rw_mod_cast [← lintegral_add_left]
swap; · exact ((fDiv_statInfoFun_stronglyMeasurable (ν.withDensity (∂μ/∂ν)) ν).measurable.comp
(by fun_prop) (f := fun x ↦ (1, x))).ereal_toENNReal
simp_rw [← EReal.toENNReal_add fDiv_statInfoFun_nonneg fDiv_statInfoFun_nonneg]
have h_ne_top : (∫⁻ x, .ofReal (statInfoFun 1 x 0) ∂curvatureMeasure f) * ν univ ≠ ⊤ := by
refine ENNReal.mul_ne_top (lt_top_iff_ne_top.mp ?_) (measure_ne_top ν _)
calc
_ ≤ ∫⁻ x, (Ioc 0 1).indicator 1 x ∂curvatureMeasure f := by
simp_rw [statInfoFun_of_one_of_right_le_one zero_le_one, sub_zero]
refine lintegral_mono (le_indicator ?_ ?_) <;> simp_all
_ < _ := by
rw [lintegral_indicator_one measurableSet_Ioc]
exact measure_Ioc_lt_top
have h_le (x : ℝ) : .ofReal (statInfoFun 1 x 0) * ν univ
≤ (fDiv (statInfoFun 1 x) (ν.withDensity (∂μ/∂ν)) ν
+ fDiv (statInfoFun 1 x) (μ.singularPart ν) ν).toENNReal := by
rw [← EReal.real_coe_toENNReal, ← EReal.toENNReal_coe (x := ν _),
← EReal.toENNReal_mul (EReal.coe_nonneg.mpr <| statInfoFun_nonneg 1 x 0)]
refine EReal.toENNReal_le_toENNReal <| (EReal.sub_nonneg ?_ ?_).mp (h_nonneg x)
<;> simp [EReal.mul_ne_top, EReal.mul_ne_bot, measure_ne_top ν univ]
rw [ENNReal.toReal_toEReal_of_ne_top (measure_ne_top ν _), ← EReal.coe_ennreal_mul,
← ENNReal.toEReal_sub h_ne_top]
swap
· exact lintegral_mul_const' _ _ (measure_ne_top ν _) ▸ lintegral_mono fun x ↦ h_le x
rw [← lintegral_mul_const' _ _ (measure_ne_top ν _),
← lintegral_sub (measurable_statInfoFun2.ennreal_ofReal.mul_const _)
(lintegral_mul_const' _ _ (measure_ne_top ν _) ▸ h_ne_top)
(eventually_of_forall fun x ↦ h_le x)]
congr with x
rw [EReal.toENNReal_sub (mul_nonneg (EReal.coe_nonneg.mpr (statInfoFun_nonneg 1 x 0))
(EReal.coe_ennreal_nonneg _)),
EReal.toENNReal_mul (EReal.coe_nonneg.mpr (statInfoFun_nonneg 1 x 0)), EReal.toENNReal_coe]
congr
simp_rw [h1, lintegral_statInfoFun_one_zero hf_cvx hf_cont, sub_eq_add_neg, add_assoc]
congr 1
simp_rw [add_comm (- (((f 0).toEReal + _) * _)), add_comm (∫⁻ _, _ ∂_).toEReal _, ← add_assoc,
← ENNReal.toReal_toEReal_of_ne_top (measure_ne_top _ _)]
norm_cast
ring_nf
simp_rw [sub_eq_add_neg, mul_assoc, ← mul_neg, ← mul_add]
congr 1
nth_rw 3 [μ.haveLebesgueDecomposition_add ν]
rw [Measure.coe_add, Pi.add_apply, ENNReal.toReal_add (measure_ne_top _ _) (measure_ne_top _ _)]
ring_nf

end StatInfoFun

section DataProcessingInequality
Expand All @@ -923,21 +1068,55 @@ lemma fDiv_statInfoFun_comp_right_le [IsFiniteMeasure μ] [IsFiniteMeasure ν]
· exact EReal.coe_ennreal_le_coe_ennreal_iff.mpr <| statInfo_comp_le _ _ _ _
· simp_rw [Measure.comp_apply_univ, le_refl]

-- The name is `fDiv_comp_right_le'`, since there is already `fDiv_comp_right_le` in the `fDiv.CompProd` file.
/-- **Data processing inequality** for the f-divergence. -/
lemma fDiv_comp_right_le_of_absolutelyContinuous [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(η : Kernel 𝒳 𝒳') [IsMarkovKernel η]
(hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) (h_ac : μ ≪ ν) :
lemma fDiv_comp_right_le' [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(η : Kernel 𝒳 𝒳') [IsMarkovKernel η] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f (η ∘ₘ μ) (η ∘ₘ ν) ≤ fDiv f μ ν := by
rw [fDiv_eq_lintegral_fDiv_statInfoFun_of_absolutelyContinuous hf_cvx hf_cont h_ac,
fDiv_eq_lintegral_fDiv_statInfoFun_of_absolutelyContinuous hf_cvx hf_cont]
swap; · exact Measure.absolutelyContinuous_comp_left h_ac _
simp_rw [Measure.comp_apply_univ]
simp_rw [fDiv_eq_lintegral_fDiv_statInfoFun hf_cvx hf_cont, Measure.comp_apply_univ]
gcongr
simp only [EReal.coe_ennreal_le_coe_ennreal_iff]
exact lintegral_mono fun x ↦ EReal.toENNReal_le_toENNReal <|
fDiv_statInfoFun_comp_right_le η zero_le_one

end DataProcessingInequality
lemma le_fDiv_compProd' [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(κ η : Kernel 𝒳 𝒳') [IsMarkovKernel κ] [IsMarkovKernel η] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f μ ν ≤ fDiv f (μ ⊗ₘ κ) (ν ⊗ₘ η) := by
nth_rw 1 [← Measure.fst_compProd μ κ, ← Measure.fst_compProd ν η]
simp_rw [Measure.fst, ← Measure.comp_deterministic_eq_map measurable_fst]
exact fDiv_comp_right_le' _ hf_cvx hf_cont

lemma fDiv_compProd_right' [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(κ : Kernel 𝒳 𝒳') [IsMarkovKernel κ] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f (μ ⊗ₘ κ) (ν ⊗ₘ κ) = fDiv f μ ν := by
refine le_antisymm ?_ (le_fDiv_compProd' κ κ hf_cvx hf_cont)
simp_rw [Measure.compProd_eq_comp]
exact fDiv_comp_right_le' _ hf_cvx hf_cont

lemma fDiv_comp_le_compProd' [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(κ η : Kernel 𝒳 𝒳') [IsMarkovKernel κ] [IsMarkovKernel η] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f (κ ∘ₘ μ) (η ∘ₘ ν) ≤ fDiv f (μ ⊗ₘ κ) (ν ⊗ₘ η) := by
nth_rw 1 [← Measure.snd_compProd μ κ, ← Measure.snd_compProd ν η]
simp_rw [Measure.snd, ← Measure.comp_deterministic_eq_map measurable_snd]
exact fDiv_comp_right_le' _ hf_cvx hf_cont

lemma fDiv_comp_le_compProd_right' [IsFiniteMeasure μ]
(κ η : Kernel 𝒳 𝒳') [IsMarkovKernel κ] [IsMarkovKernel η] (hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f (κ ∘ₘ μ) (η ∘ₘ μ) ≤ fDiv f (μ ⊗ₘ κ) (μ ⊗ₘ η) :=
fDiv_comp_le_compProd' κ η hf_cvx hf_cont

lemma fDiv_fst_le' (μ ν : Measure (𝒳 × 𝒳')) [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f μ.fst ν.fst ≤ fDiv f μ ν := by
simp_rw [Measure.fst, ← Measure.comp_deterministic_eq_map measurable_fst]
exact fDiv_comp_right_le' _ hf_cvx hf_cont

lemma fDiv_snd_le' (μ ν : Measure (𝒳 × 𝒳')) [IsFiniteMeasure μ] [IsFiniteMeasure ν]
(hf_cvx : ConvexOn ℝ univ f) (hf_cont : Continuous f) :
fDiv f μ.snd ν.snd ≤ fDiv f μ ν := by
simp_rw [Measure.snd, ← Measure.comp_deterministic_eq_map measurable_snd]
exact fDiv_comp_right_le' _ hf_cvx hf_cont

end DataProcessingInequality

end ProbabilityTheory
Loading

0 comments on commit b209c28

Please sign in to comment.