Skip to content

Commit

Permalink
jfdist_cond done
Browse files Browse the repository at this point in the history
  • Loading branch information
affeldt-aist committed Jul 17, 2024
1 parent a9b3b12 commit 57db054
Showing 1 changed file with 71 additions and 61 deletions.
132 changes: 71 additions & 61 deletions probability/jfdist_cond.v
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
(* Copyright (C) 2020 infotheo authors, license: LGPL-2.1-or-later *)
From mathcomp Require Import all_ssreflect ssralg ssrnum matrix.
From mathcomp Require boolp.
From mathcomp Require Import Rstruct.
Require Import Reals.
Require Import ssrR realType_ext Reals_ext logb ssr_ext ssralg_ext bigop_ext.
From mathcomp Require Import reals.
Require Import realType_ext realType_logb ssr_ext ssralg_ext bigop_ext.
Require Import fdist proba.

(******************************************************************************)
Expand Down Expand Up @@ -34,12 +33,15 @@ Set Implicit Arguments.
Unset Strict Implicit.
Import Prenex Implicits.

Local Open Scope R_scope.
Local Open Scope ring_scope.
Local Open Scope proba_scope.
Local Open Scope fdist_scope.

Import GRing.Theory.

Section conditional_probability.
Variables (A B : finType) (P : {fdist A * B}).
Context {R : realType}.
Variables (A B : finType) (P : R.-fdist (A * B)).
Implicit Types (E : {set A}) (F : {set B}).

Definition jcPr E F := Pr P (E `* F) / Pr (P`2) F.
Expand Down Expand Up @@ -98,7 +100,7 @@ Hypothesis cov : cover (F @: I) = [set: B].
Lemma jtotal_prob_cond : Pr P`1 E = \sum_(i in I) \Pr_[E | F i] * Pr P`2 (F i).
Proof.
rewrite -Pr_XsetT -EsetT.
rewrite (@total_prob_cond _ _ _ _ (fun i => T`* F i)); last 2 first.
rewrite (@total_prob_cond _ _ _ _ _ (fun i => T`* F i)); last 2 first.
- move=> i j ij; rewrite -setI_eq0 !setTE setIX setTI.
by move: (dis ij); rewrite -setI_eq0 => /eqP ->; rewrite setX0.
- (* TODO: lemma? *) apply/setP => -[a b]; rewrite inE /cover.
Expand All @@ -125,7 +127,8 @@ Notation jcPr_cplt := jcPr_setC (only parsing).
Notation jcPr_union_eq := jcPr_setU (only parsing).

Section jPr_Pr.
Variables (U : finType) (P : {fdist U}) (A B : finType).
Context {R : realType}.
Variables (U : finType) (P : R.-fdist U) (A B : finType).
Variables (X : {RV P -> A}) (Y : {RV P -> B}) (E : {set A}) (F : {set B}).

Lemma jPr_Pr : \Pr_(`p_[% X, Y]) [E | F] = `Pr[X \in E |Y \in F].
Expand All @@ -139,13 +142,14 @@ Qed.
End jPr_Pr.

Section bayes.
Variables (A B : finType) (PQ : {fdist A * B}).
Context {R : realType}.
Variables (A B : finType) (PQ : R.-fdist (A * B)).
Let P := PQ`1. Let Q := PQ`2. Let QP := fdistX PQ.
Implicit Types (E : {set A}) (F : {set B}).

Lemma jBayes E F : \Pr_PQ[E | F] = \Pr_QP [F | E] * Pr P E / Pr Q F.
Proof.
rewrite 2!jcPrE Bayes /Rdiv -2!mulRA.
rewrite 2!jcPrE Bayes -2!mulrA.
rewrite EsetT Pr_XsetT setTE Pr_setTX /cPr; congr ((_ / _) * (_ / _)).
by rewrite EsetT setTE [in RHS]setIX Pr_fdistX setIX.
by rewrite setTE Pr_fdistX.
Expand All @@ -159,15 +163,16 @@ Lemma jBayes_extended (I : finType) (E : I -> {set A}) (F : {set B}) :
\sum_(j in I) \Pr_ QP [F | E j] * Pr P (E j).
Proof.
move=> dis cov i; rewrite jBayes; congr (_ / _).
move: (@jtotal_prob_cond _ _ QP I F E dis cov).
move: (@jtotal_prob_cond _ _ _ QP I F E dis cov).
rewrite {1}/QP fdistX1 => ->.
by apply eq_bigr => j _; rewrite -/QP {2}/QP fdistX2.
Qed.

End bayes.

Section conditional_probability_prop3.
Variables (A B C : finType) (P : {fdist A * B * C}).
Context {R : realType}.
Variables (A B C : finType) (P : R.-fdist (A * B * C)).

Lemma jcPr_TripC12 (E : {set A}) (F : {set B }) (G : {set C}) :
\Pr_(fdistC12 P)[F `* E | G] = \Pr_P[E `* F | G].
Expand Down Expand Up @@ -195,23 +200,25 @@ End conditional_probability_prop3.
Section product_rule.

Section main.
Variables (A B C : finType) (P : {fdist A * B * C}).
Context {R : realType}.
Variables (A B C : finType) (P : R.-fdist (A * B * C)).
Implicit Types (E : {set A}) (F : {set B}) (G : {set C}).

Lemma jproduct_rule_cond E F G :
\Pr_P [E `* F | G] = \Pr_(fdistA P) [E | F `* G] * \Pr_(fdist_proj23 P) [F | G].
Proof.
rewrite /jcPr; rewrite !mulRA; congr (_ * _); last by rewrite fdist_proj23_snd.
rewrite -mulRA -/(fdist_proj23 _) -Pr_fdistA.
case/boolP : (Pr (fdist_proj23 P) (F `* G) == 0) => H; last by rewrite mulVR ?mulR1.
suff -> : Pr (fdistA P) (E `* (F `* G)) = 0 by rewrite mul0R.
rewrite /jcPr; rewrite !mulrA; congr (_ * _); last by rewrite fdist_proj23_snd.
rewrite -mulrA -/(fdist_proj23 _) -Pr_fdistA.
case/boolP : (Pr (fdist_proj23 P) (F `* G) == 0) => H; last by rewrite mulVf ?mulr1.
suff -> : Pr (fdistA P) (E `* (F `* G)) = 0 by rewrite mul0r.
by rewrite Pr_fdistA; exact/Pr_fdist_proj23_domin/eqP.
Qed.

End main.

Section variant.
Variables (A B C : finType) (P : {fdist A * B * C}).
Context {R : realType}.
Variables (A B C : finType) (P : R.-fdist (A * B * C)).
Implicit Types (E : {set A}) (F : {set B}) (G : {set C}).

Lemma product_ruleC E F G :
Expand All @@ -221,47 +228,48 @@ Proof. by rewrite -jcPr_TripC12 jproduct_rule_cond. Qed.
End variant.

Section prod.
Variables (A B : finType) (P : {fdist A * B}).
Context {R : realType}.
Variables (A B : finType) (P : R.-fdist (A * B)).
Implicit Types (E : {set A}) (F : {set B}).

Lemma jproduct_rule E F : Pr P (E `* F) = \Pr_P[E | F] * Pr (P`2) F.
Proof.
have [/eqP PF0|PF0] := boolP (Pr (P`2) F == 0).
rewrite jcPrE /cPr -{1}(setIT E) -{1}(setIT F) -setIX.
rewrite [LHS]Pr_domin_setI; last by rewrite -Pr_fdistX Pr_domin_setX // fdistX1.
by rewrite setIC Pr_domin_setI ?(div0R,mul0R) // setTE Pr_setTX.
by rewrite setIC Pr_domin_setI ?mul0r // setTE Pr_setTX.
rewrite -{1}(setIT E) -{1}(setIT F) -setIX product_rule.
rewrite -EsetT setTT cPrET Pr_setT mulR1 jcPrE.
rewrite -EsetT setTT cPrET Pr_setT mulr1 jcPrE.
rewrite /cPr {1}setTE {1}EsetT.
by rewrite setIX setTI setIT setTE Pr_setTX -mulRA mulVR ?mulR1.
by rewrite setIX setTI setIT setTE Pr_setTX -mulrA mulVf ?mulr1.
Qed.

End prod.

End product_rule.

Lemma jcPr_fdistmap_r (A B B' : finType) (f : B -> B') (d : {fdist A * B})
Lemma jcPr_fdistmap_r {R : realType} (A B B' : finType) (f : B -> B') (d : R.-fdist (A * B))
(E : {set A}) (F : {set B}): injective f ->
\Pr_d [E | F] = \Pr_(fdistmap (fun x => (x.1, f x.2)) d) [E | f @: F].
Proof.
move=> injf; rewrite /jcPr; congr (_ / _).
- rewrite (@Pr_fdistmap _ _ (fun x => (x.1, f x.2))) /=; last first.
- rewrite (@Pr_fdistmap _ _ _ (fun x => (x.1, f x.2))) /=; last first.
by move=> [? ?] [? ?] /= [-> /injf ->].
congr (Pr _ _); apply/setP => -[a b]; rewrite !inE /=.
apply/imsetP/andP.
- case=> -[a' b']; rewrite inE /= => /andP[a'E b'F] [->{a} ->{b}]; split => //.
apply/imsetP; by exists b'.
- case=> aE /imsetP[b' b'F] ->{b}; by exists (a, b') => //; rewrite inE /= aE.
by rewrite /fdist_snd fdistmap_comp (@Pr_fdistmap _ _ f) // fdistmap_comp.
by rewrite /fdist_snd fdistmap_comp (@Pr_fdistmap _ _ _ f) // fdistmap_comp.
Qed.
Arguments jcPr_fdistmap_r [A] [B] [B'] [f] [d] [E] [F] _.
Arguments jcPr_fdistmap_r {R} [A] [B] [B'] [f] [d] [E] [F] _.

Lemma jcPr_fdistmap_l (A A' B : finType) (f : A -> A') (d : {fdist A * B})
Lemma jcPr_fdistmap_l {R : realType} (A A' B : finType) (f : A -> A') (d : R.-fdist (A * B))
(E : {set A}) (F : {set B}): injective f ->
\Pr_d [E | F] = \Pr_(fdistmap (fun x => (f x.1, x.2)) d) [f @: E | F].
Proof.
move=> injf; rewrite /jcPr; congr (_ / _).
- rewrite (@Pr_fdistmap _ _ (fun x => (f x.1, x.2))) /=; last first.
- rewrite (@Pr_fdistmap _ _ _ (fun x => (f x.1, x.2))) /=; last first.
by move=> [? ?] [? ?] /= [/injf -> ->].
congr (Pr _ _); apply/setP => -[a b]; rewrite !inE /=.
apply/imsetP/andP.
Expand All @@ -270,48 +278,50 @@ move=> injf; rewrite /jcPr; congr (_ / _).
- by case=> /imsetP[a' a'E] ->{a} bF; exists (a', b) => //; rewrite inE /= a'E.
by rewrite /fdist_snd !fdistmap_comp.
Qed.
Arguments jcPr_fdistmap_l [A] [A'] [B] [f] [d] [E] [F] _.
Arguments jcPr_fdistmap_l {R} [A] [A'] [B] [f] [d] [E] [F] _.

Lemma Pr_jcPr_unit (A : finType) (E : {set A}) (P : {fdist A}) :
Lemma Pr_jcPr_unit {R : realType} (A : finType) (E : {set A}) (P : R.-fdist A) :
Pr P E = \Pr_(fdistmap (fun a => (a, tt)) P) [E | setT].
Proof.
rewrite /jcPr/= (_ : [set: unit] = [set tt]); last first.
by apply/setP => -[]; rewrite !inE eqxx.
rewrite (Pr_set1 _ tt).
rewrite (_ : _`2 = fdist1 tt) ?fdist1xx ?divR1; last first.
rewrite (_ : _`2 = fdist1 tt) ?fdist1xx ?divr1; last first.
rewrite /fdist_snd fdistmap_comp; apply/fdist_ext; case.
by rewrite fdistmapE fdist1xx (eq_bigl xpredT) // FDist.f1.
rewrite /Pr big_setX /=; apply eq_bigr => a _; rewrite (big_set1 _ tt) /=.
rewrite /Pr big_setX /=; apply: eq_bigr => a _; rewrite (big_set1 _ tt) /=.
rewrite fdistmapE (big_pred1 a) // => a0; rewrite inE /=.
by apply/eqP/eqP => [[] -> | ->].
Qed.

Section jfdist_cond0.
Variables (A B : finType) (PQ : {fdist (A * B)}) (a : A).
Context {R : realType}.
Variables (A B : finType) (PQ : R.-fdist (A * B)) (a : A).
Hypothesis Ha : PQ`1 a != 0.

Let f := [ffun b => \Pr_(fdistX PQ) [[set b] | [set a]]].

Let f0 b : 0 <= f b. Proof. rewrite ffunE; exact: jcPr_ge0. Qed.

Let f0' b : (0 <= f b)%O. Proof. by apply/RleP. Qed.
Let f0' b : (0 <= f b)%O. Proof. by []. Qed.

Let f1 : \sum_(b in B) f b = 1.
Proof.
under eq_bigr do rewrite ffunE.
by rewrite /jcPr -big_distrl /= PrX_snd mulRV // Pr_set1 fdistX2.
by rewrite /jcPr -big_distrl /= PrX_snd mulfV // Pr_set1 fdistX2.
Qed.

Definition jfdist_cond0 : {fdist B} := locked (@FDist.make _ _ _ f0' f1).
Definition jfdist_cond0 : R.-fdist B := locked (@FDist.make _ _ _ f0' f1).

Lemma jfdist_cond0E b : jfdist_cond0 b = \Pr_(fdistX PQ) [[set b] | [set a]].
Proof. by rewrite /jfdist_cond0; unlock; rewrite ffunE. Qed.

End jfdist_cond0.
Arguments jfdist_cond0 {A} {B} _ _ _.
Arguments jfdist_cond0 {R} {A} {B} _ _ _.

Section jfdist_cond.
Variables (A B : finType) (PQ : {fdist A * B}) (a : A).
Context {R : realType}.
Variables (A B : finType) (PQ : R.-fdist (A * B)) (a : A).
Let Ha := PQ`1 a != 0.

Let sizeB : #|B| = #|B|.-1.+1.
Expand Down Expand Up @@ -339,7 +349,7 @@ Qed.
End jfdist_cond.
Notation "P `(| a ')'" := (jfdist_cond P a).

Lemma cPr_1 (U : finType) (P : {fdist U}) (A B : finType)
Lemma cPr_1 {R : realType} (U : finType) (P : R.-fdist U) (A B : finType)
(X : {RV P -> A}) (Y : {RV P -> B}) a : `Pr[X = a] != 0 ->
\sum_(b <- fin_img Y) `Pr[ Y = b | X = a ] = 1.
Proof.
Expand All @@ -350,46 +360,47 @@ rewrite [X in _ = _ + X](eq_bigr (fun=> 0)); last first.
move=> b bY.
rewrite /Q jfdist_condE // /jcPr /Pr !(big_setX,big_set1) /= fdistXE fdistX2 fst_RV2.
rewrite -!pr_eqE' !pr_eqE.
rewrite /Pr big1 ?div0R // => u.
rewrite /Pr big1 ?mul0r // => u.
rewrite inE => /eqP[Yub ?].
exfalso.
move/negP : bY; apply.
by rewrite mem_undup; apply/mapP; exists u => //; rewrite mem_enum.
rewrite big_const iter_addR mulR0 addR0.
rewrite big_const iter_addr mul0rn !addr0.
rewrite big_uniq; last by rewrite /fin_img undup_uniq.
apply eq_bigr => b; rewrite mem_undup => /mapP[u _ bWu].
rewrite /Q jfdist_condE // fdistX_RV2.
by rewrite jcPrE -cpr_inE' cpr_eq_set1.
Qed.

Lemma jcPr_1 (A B : finType) (P : {fdist A * B}) a : P`1 a != 0 ->
Lemma jcPr_1 {R : realType} (A B : finType) (P : R.-fdist (A * B)) a : P`1 a != 0 ->
\sum_(b in B) \Pr_(fdistX P)[ [set b] | [set a] ] = 1.
Proof.
move=> Xa0; rewrite -[RHS](FDist.f1 (P `(| a ))); apply eq_bigr => b _.
by rewrite jfdist_condE.
Qed.

Lemma jfdist_cond_prod (A B : finType) (P : {fdist A}) (W : A -> {fdist B}) (a : A) :
Lemma jfdist_cond_prod {R : realType} (A B : finType) (P : R.-fdist A) (W : A -> R.-fdist B) (a : A) :
(P `X W)`1 a != 0 -> W a = (P `X W) `(| a ).
Proof.
move=> a0; apply/fdist_ext => b.
rewrite jfdist_condE // /jcPr setX1 !Pr_set1 fdistXE fdistX2 fdist_prod1.
rewrite fdist_prodE /= /Rdiv mulRAC mulRV ?mul1R //.
rewrite fdist_prodE /= mulrAC mulfV ?mul1r //.
by move: a0; rewrite fdist_prod1.
Qed.

Lemma jcPr_fdistX_prod (A B : finType) (P : {fdist A}) (W : A -> {fdist B}) a b :
Lemma jcPr_fdistX_prod {R : realType} (A B : finType) (P : R.-fdist A) (W : A -> R.-fdist B) a b :
P a <> 0 -> \Pr_(fdistX (P `X W))[ [set b] | [set a] ] = W a b.
Proof.
move=> Pxa.
rewrite /jcPr setX1 fdistX2 2!Pr_set1 fdistXE fdist_prod1.
by rewrite fdist_prodE /= /Rdiv mulRAC mulRV ?mul1R //; exact/eqP.
by rewrite fdist_prodE /= mulrAC mulfV ?mul1r //; exact/eqP.
Qed.

Section fdist_split.
Context {R : realType}.
Variables (A B : finType).

Definition fdist_split (PQ : {fdist A * B}) := (PQ`1, fun x => PQ `(| x )).
Definition fdist_split (PQ : R.-fdist (A * B)) := (PQ`1, fun x => PQ `(| x )).

Lemma fdist_prodK : cancel fdist_split (uncurry (@fdist_prod _ A B)).
Proof.
Expand All @@ -398,19 +409,19 @@ have [Ha|Ha] := eqVneq (PQ`1 ab.1) 0.
rewrite Ha GRing.mul0r; apply/esym/(dominatesE (Prod_dominates_Joint PQ)).
by rewrite fdist_prodE Ha GRing.mul0r.
rewrite jfdist_condE // -fdistX2 GRing.mulrC.
rewrite -(Pr_set1 _ ab.1) -RmultE -jproduct_rule setX1 Pr_set1 fdistXE.
rewrite -(Pr_set1 _ ab.1) -jproduct_rule setX1 Pr_set1 fdistXE.
by case ab.
Qed.

End fdist_split.


Import GRing.Theory Num.Theory.
Import Num.Theory.

Module FDistPart.
Section fdistpart.
Context {R: realType}.
Local Open Scope fdist_scope.
Variables (n m : nat) (K : 'I_m -> 'I_n) (e : {fdist 'I_m}) (i : 'I_n).
Variables (n m : nat) (K : 'I_m -> 'I_n) (e : R.-fdist 'I_m) (i : 'I_n).

Definition d := (fdistX (e `X (fun j => fdist1 (K j)))) `(| i).
Definition den := (fdistX (e `X (fun j => fdist1 (K j))))`1 i.
Expand All @@ -426,44 +437,43 @@ rewrite eq_sym 2!inE.
by case: eqP => // _; rewrite (mulr0,mulr1).
Qed.

Lemma dE j : fdistmap K e i != 0%coqR ->
d j = (e j * (i == K j)%:R / \sum_(j | K j == i) e j)%coqR.
Lemma dE j : fdistmap K e i != 0 ->
d j = (e j * (i == K j)%:R / \sum_(j | K j == i) e j).
Proof.
rewrite -denE => NE.
rewrite jfdist_condE // {NE} /jcPr /proba.Pr.
rewrite (big_pred1 (j,i)); last first.
by move=> k; rewrite !inE [in RHS](surjective_pairing k) xpair_eqE.
rewrite (big_pred1 i); last by move=> k; rewrite !inE.
rewrite !fdistE big_mkcond [in RHS]big_mkcond /=.
rewrite -RmultE -INRE.
congr (_ / _)%R.
under eq_bigr => k do rewrite {2}(surjective_pairing k).
rewrite -(pair_bigA _ (fun k l =>
if l == i
then e `X (fun j0 : 'I_m => fdist1 (K j0)) (k, l)
else R0))%R /=.
else 0))%R /=.
apply eq_bigr => k _.
rewrite -big_mkcond /= big_pred1_eq !fdistE /= eq_sym.
by case: ifP; rewrite (mulr1,mulr0).
Qed.
End fdistpart.

Lemma dK n m K (e : {fdist 'I_m}) j :
Lemma dK {R : realType} n m K (e : R.-fdist 'I_m) j :
e j = (\sum_(i < n) fdistmap K e i * d K e i j)%R.
Proof.
under eq_bigr => /= a _.
have [Ka0|Ka0] := eqVneq (fdistmap K e a) 0%R.
rewrite Ka0 mul0R.
rewrite Ka0 mul0r.
have <- : (e j * (a == K j)%:R = 0)%R.
have [/eqP Kj|] := eqVneq a (K j); last by rewrite mulR0.
have [/eqP Kj|] := eqVneq a (K j); last by rewrite mulr0.
move: Ka0; rewrite fdistE /=.
by move/psumr_eq0P => -> //; rewrite ?(mul0R,inE) // eq_sym.
by move/psumr_eq0P => -> //; rewrite ?(mul0r,inE) // eq_sym.
over.
rewrite FDistPart.dE // fdistE /= mulRCA mulRV ?mulR1;
rewrite FDistPart.dE // fdistE /= mulrCA mulfV ?mulr1;
last by rewrite fdistE in Ka0.
over.
move=> /=.
rewrite (bigD1 (K j)) //= eqxx mulR1.
by rewrite big1 ?addR0 // => i /negbTE ->; rewrite mulR0.
rewrite (bigD1 (K j)) //= eqxx mulr1.
by rewrite big1 ?addr0 // => i /negbTE ->; rewrite mulr0.
Qed.
End FDistPart.

0 comments on commit 57db054

Please sign in to comment.