diff --git a/Bits.v b/Bits.v index 4e7838b..9be2f4f 100644 --- a/Bits.v +++ b/Bits.v @@ -1,5 +1,5 @@ -Require Export Pad. Require Export CauchySchwarz. +Require Import Modulus. Local Open Scope nat_scope. @@ -399,6 +399,15 @@ Proof. assumption. Qed. +Lemma nat_lt_pow2_funbool_to_nat_ind (P : nat -> Prop) n + (H : forall f, P (funbool_to_nat n f)) : + forall i, i < 2 ^ n -> P i. +Proof. + intros i Hi. + rewrite <- (nat_to_funbool_inverse n i Hi). + apply H. +Qed. + Local Opaque Nat.mul. Lemma nat_to_binlist'_even : forall n, (n > 0)%nat -> nat_to_binlist' (2 * n) = false :: nat_to_binlist' n. @@ -476,90 +485,6 @@ Proof. apply IHn; lia. Qed. -Lemma funbool_to_nat_inverse : forall len f i, (i < len)%nat -> - nat_to_funbool len (funbool_to_nat len f) i = f i. -Proof. - intros. - assert (list_to_funbool_append1 : forall l1 l2, - (i >= length l2)%nat -> - (len <= length l1 + length l2)%nat -> - list_to_funbool len (l1 ++ l2) i = list_to_funbool len l1 i). - { intros. - generalize dependent len. - induction l1; intros; simpl in *. - generalize dependent len. - induction l2. - reflexivity. - intros. - simpl in *. - unfold update. - bdestructΩ (i =? len - 1). - unfold update. - bdestruct (i =? len - 1). - reflexivity. - apply IHl1; lia. } - assert (list_to_funbool_append2 : forall l1 l2, - (i < length l2)%nat -> - (len >= length l1 + length l2)%nat -> - list_to_funbool len (l1 ++ l2) i = - list_to_funbool (len - length l1) l2 i). - { clear. - intros. - generalize dependent len. - induction l1; intros; simpl in *. - rewrite Nat.sub_0_r. - reflexivity. - unfold update. - bdestructΩ (i =? len - 1). - rewrite IHl1 by lia. - replace (len - 1 - length l1)%nat with (len - S (length l1))%nat by lia. - reflexivity. } - unfold nat_to_funbool, funbool_to_nat, nat_to_binlist. - remember (binlist_to_nat (funbool_to_list len f)) as n. - bdestructΩ (len - length (nat_to_binlist' n) <=? i). - rewrite list_to_funbool_append1. - all: try rewrite repeat_length; try lia. - subst. - rewrite binlist_to_nat_inverse. - clear - H. - induction len. - lia. - simpl. - rewrite Nat.sub_0_r. - bdestruct (i =? len). subst. - rewrite update_index_eq. - reflexivity. - rewrite update_index_neq by lia. - rewrite IHlen by lia. - reflexivity. - rewrite list_to_funbool_append2. - all: try rewrite repeat_length; try lia. - assert (f i = false). - { subst. - clear - H0. - induction len. - simpl in H0. lia. - remember (binlist_to_nat (funbool_to_list (S len) f)) as n. - bdestruct (n =? 0). - subst. rewrite H in *. - eapply funbool_to_nat_0. apply H. - lia. - apply IHlen. - subst. - simpl in *. - destruct (f len); simpl Nat.b2n in *. - rewrite Nat.add_comm in H0. - rewrite nat_to_binlist'_odd in H0. - simpl in H0. lia. - rewrite Nat.add_0_l in *. - rewrite nat_to_binlist'_even in H0 by lia. - simpl in H0. lia. } - rewrite list_to_funbool_repeat_false. - rewrite H1. - reflexivity. -Qed. -Local Transparent Nat.mul. - Lemma testbit_binlist {n : nat} {k : list bool} : Nat.testbit (binlist_to_nat k) n = nth n k false. Proof. @@ -568,85 +493,24 @@ Proof. intros k. - cbn. destruct k; [easy|]. - destruct b; cbn; - rewrite Nat.add_0_r. + destruct b; cbn. 2: rewrite <- Nat.negb_even; symmetry; apply negb_sym; cbn. 1: rewrite Nat.odd_succ. - all: rewrite Nat.even_add; - apply eqb_reflx. + all: rewrite Nat.even_mul; easy. - destruct k. + rewrite Nat.testbit_0_l; easy. + simpl. destruct b; simpl Nat.b2n. * rewrite Nat.add_1_l. - rewrite Nat.add_0_r, double_mult. rewrite div2_S_double. apply IHn. - * rewrite Nat.add_0_l, Nat.add_0_r, double_mult. + * rewrite Nat.add_0_l. rewrite Nat.div2_double. apply IHn. Qed. -Lemma binlist_mod {k : list bool} {n0 : nat} : - (binlist_to_nat k) mod (2^n0) = binlist_to_nat (firstn n0 k). -Proof. - apply Nat.bits_inj. - intros n. - rewrite testbit_binlist. - bdestruct (n if n <=? (n0 - 1) then nth (n0 - S n) k false else false. Proof. @@ -703,14 +574,6 @@ Proof. rewrite list_to_funbool_eq. easy. Qed. -Lemma nth_nat_to_binlist {len n} : forall k, - nth k (nat_to_binlist len n) false = Nat.testbit n k. -Proof. - intros k. - rewrite <- testbit_binlist, nat_to_binlist_inverse. - easy. -Qed. - Lemma nat_to_funbool_eq {n j} : nat_to_funbool n j = fun k => if k <=? n - 1 then Nat.testbit j (n - S k) else false. Proof. @@ -720,6 +583,75 @@ Proof. easy. Qed. +Lemma funbool_to_nat_inverse : forall len f i, (i < len)%nat -> + nat_to_funbool len (funbool_to_nat len f) i = f i. +Proof. + intros. + rewrite nat_to_funbool_eq. + rewrite testbit_funbool_to_nat. + rewrite sub_S_sub_S by easy. + bdestructΩ'. +Qed. +Local Transparent Nat.mul. + +Lemma binlist_mod {k : list bool} {n0 : nat} : + (binlist_to_nat k) mod (2^n0) = binlist_to_nat (firstn n0 k). +Proof. + apply Nat.bits_inj. + intros n. + rewrite testbit_binlist. + bdestruct (n nat_to_funbool n1 (j mod 2 ^ n1) k = nat_to_funbool (n0 + n1) j (k + n0). Proof. @@ -742,6 +674,57 @@ Proof. lia. Qed. +Lemma funbool_to_nat_add_pow2_join n f g m : + funbool_to_nat n f * 2 ^ m + funbool_to_nat m g = + funbool_to_nat (n + m) (fun k => if k f (n + k - (min n m))). +Proof. + apply bits_inj. + intros s. + rewrite testbit_mod_pow2, 2!testbit_funbool_to_nat. + rewrite min_ltb. + bdestructΩ'; f_equal; lia. +Qed. + +Lemma funbool_to_nat_eq_iff n f g : + (forall k, k < n -> f k = g k) <-> funbool_to_nat n f = funbool_to_nat n g. +Proof. + split; + [apply funbool_to_nat_eq|]. + intros H k Hk. + apply (f_equal (fun f => Nat.testbit f (n - S k))) in H. + revert H. + rewrite 2!testbit_funbool_to_nat. + simplify_bools_lia. + now replace (n - S (n - S k)) with k by lia. +Qed. + +Lemma nat_to_funbool_eq' n j k : + nat_to_funbool n j k = + if k <=? n - 1 then Nat.testbit j (n - S k) else false. +Proof. + now rewrite nat_to_funbool_eq. +Qed. + Fixpoint product (x y : nat -> bool) n := match n with | O => false @@ -801,3 +784,37 @@ Proof. rewrite list_to_funbool_repeat_false. reflexivity. Qed. + +Lemma nat_to_funbool_add_pow2_split i j n m + (Hi : i < 2 ^ n) (Hj : j < 2 ^ m) : + nat_to_funbool (n + m) (i * 2 ^ m + j) = + (fun s => + if s nat_to_funbool n i s = nat_to_funbool n j s) <-> + i = j. +Proof. + split; [|now intros ->]. + intros Hij. + rewrite <- (bits_inj_upto_small i j n) by assumption. + intros s Hs. + generalize (Hij (n - S s) ltac:(lia)). + rewrite 2!nat_to_funbool_eq. + simplify_bools_lia_one_kernel. + now rewrite sub_S_sub_S. +Qed. \ No newline at end of file diff --git a/CauchySchwarz.v b/CauchySchwarz.v index 943c6d9..cd2a25a 100644 --- a/CauchySchwarz.v +++ b/CauchySchwarz.v @@ -362,6 +362,40 @@ Qed. Local Close Scope nat_scope. +Lemma subnormal_matrix_le_1 {n m : nat} (A : Matrix n m) {i j} + (Hi : (i < n)%nat) (Hj : (j < m)%nat) + (Hnorm : forall i, (i < m)%nat -> norm (get_col A i) <= 1) : + Cmod (A i j) <= 1. +Proof. + apply Rle_pow_le_nonneg with 1%nat; try lra; [apply Cmod_ge_0|]. + rewrite pow1. + specialize (Hnorm j Hj). + revert Hnorm. + unfold get_col, norm, inner_product. + autounfold with U_db. + rewrite Nat.eqb_refl. + rewrite (big_sum_eq_bounded _ (fun k => RtoC (Cmod (A k j) ^ 2))) + by (intros; rewrite <- Cmod_sqr, RtoC_pow; easy). + rewrite Rsum_big_sum. + intros H. + rewrite <- sqrt_1 in H. + apply sqrt_le_0 in H; + [|apply Rsum_ge_0_on;intros;apply pow2_ge_0|lra]. + refine (Rle_trans _ _ _ _ H). + apply (Rsum_nonneg_ge_any n (fun k => Cmod (A k j) ^ 2)%R i); [easy|]. + intros; apply pow2_ge_0. +Qed. + +Lemma normal_matrix_le_1 {n m : nat} (A : Matrix n m) {i j} + (Hi : (i < n)%nat) (Hj : (j < m)%nat) + (Hnorm : forall i, (i < m)%nat -> norm (get_col A i) = 1) : + Cmod (A i j) <= 1. +Proof. + apply subnormal_matrix_le_1; [easy..|]. + intros. + rewrite Hnorm; easy + lra. +Qed. + (* We can now prove Cauchy-Schwartz for vectors with inner_product *) Lemma CS_key_lemma : forall {n} (u v : Vector n), fst ⟨ (⟨v,v⟩ .* u .+ -1 * ⟨v,u⟩ .* v), (⟨v,v⟩ .* u .+ -1 * ⟨v,u⟩ .* v) ⟩ = diff --git a/Complex.v b/Complex.v index 2fd4f91..51f5a35 100644 --- a/Complex.v +++ b/Complex.v @@ -259,17 +259,85 @@ Proof. apply Rsqr_le_abs_1 in H0 ; unfold pow; rewrite !Rmult_1_r; auto. Qed. +Lemma Cmod_ge_fst z : + fst z <= Cmod z. +Proof. + unfold Cmod. + apply sqrt_ge. + pose proof (pow2_ge_0 (snd z)). + lra. +Qed. + +Lemma Cmod_ge_snd z : + snd z <= Cmod z. +Proof. + unfold Cmod. + apply sqrt_ge. + pose proof (pow2_ge_0 (fst z)). + lra. +Qed. + +Lemma Cmod_ge_abs_fst z : + Rabs (fst z) <= Cmod z. +Proof. + unfold Cmod. + apply sqrt_ge_abs. + pose proof (pow2_ge_0 (snd z)). + lra. +Qed. + +Lemma Cmod_ge_abs_snd z : + Rabs (snd z) <= Cmod z. +Proof. + unfold Cmod. + apply sqrt_ge_abs. + pose proof (pow2_ge_0 (fst z)). + lra. +Qed. + +Lemma Cmod_plus_fst_ge_0 z : + 0 <= Cmod z + fst z. +Proof. + rewrite Rplus_comm. + apply Rplus_ge_0_of_ge_Rabs. + apply Cmod_ge_abs_fst. +Qed. + +Lemma Cmod_plus_snd_ge_0 z : + 0 <= Cmod z + snd z. +Proof. + rewrite Rplus_comm. + apply Rplus_ge_0_of_ge_Rabs. + apply Cmod_ge_abs_snd. +Qed. + +Lemma C_neq_iff : forall c d : C, c <> d <-> (fst c <> fst d \/ snd c <> snd d). +Proof. + intros [cr ci] [dr di]. + split. + - intros Hne. + destruct (Req_dec cr dr); [|now left]. + destruct (Req_dec ci di); [|now right]. + subst; easy. + - simpl. + intros []; congruence. +Qed. + Lemma C_neq_0 : forall c : C, c <> 0 -> (fst c) <> 0 \/ (snd c) <> 0. Proof. - intros. - apply Classical_Prop.not_and_or. - rewrite <- pair_equal_spec. - unfold not in *. - replace ((fst c, snd c)) with c by apply surjective_pairing. - replace (0, 0)%C with (RtoC 0) by reflexivity. - assumption. + intros c. + apply C_neq_iff. Qed. +Lemma Cinv_0 : / 0 = 0. +Proof. + lca. +Qed. + +Lemma Cdiv_0_r z : z / 0 = 0. +Proof. + lca. +Qed. (* some lemmas to help simplify addition/multiplication scenarios *) Lemma Cplus_simplify : forall (a b c d : C), @@ -307,11 +375,17 @@ Proof. apply injective_projections ; simpl ; ring. Qed. -Lemma RtoC_inv (x : R) : (x <> 0)%R -> RtoC (/ x) = / RtoC x. -Proof. intros Hx. apply injective_projections ; simpl ; field ; auto. Qed. +Lemma RtoC_inv (x : R) : RtoC (/ x) = / RtoC x. +Proof. destruct (Req_dec x 0). + - subst; now rewrite Cinv_0, Rinv_0. + - apply injective_projections ; simpl ; field ; auto. +Qed. -Lemma RtoC_div (x y : R) : (y <> 0)%R -> RtoC (x / y) = RtoC x / RtoC y. -Proof. intros Hy. apply injective_projections ; simpl ; field ; auto. Qed. +Lemma RtoC_div (x y : R) : RtoC (x / y) = RtoC x / RtoC y. +Proof. destruct (Req_dec y 0). + - subst; unfold Rdiv; now rewrite Cdiv_0_r, Rinv_0, Rmult_0_r. + - apply injective_projections ; simpl ; field ; auto. +Qed. Lemma Cplus_comm (x y : C) : x + y = y + x. Proof. apply injective_projections ; simpl ; apply Rplus_comm. Qed. @@ -371,6 +445,34 @@ rewrite Cmult_comm. now apply Cinv_r. Qed. +Lemma Cdiv_mult_r (c d : C) : d <> 0%R -> + c / d * d = c. +Proof. + intros. + unfold Cdiv. + rewrite <- Cmult_assoc, Cinv_l by easy. + apply Cmult_1_r. +Qed. + +Lemma Cdiv_mult_l (c d : C) : d <> 0%R -> + d * c / d = c. +Proof. + intros. + unfold Cdiv. + rewrite (Cmult_comm d c), <- Cmult_assoc, (Cmult_comm d), Cmult_assoc. + apply Cdiv_mult_r. + easy. +Qed. + +Lemma Cdiv_mult_l' (c d : C) : d <> 0%R -> + d * (c / d) = c. +Proof. + intros. + rewrite Cmult_comm. + apply Cdiv_mult_r. + easy. +Qed. + Lemma Cdiv_1_r : forall c, c / C1 = c. Proof. intros. lca. Qed. @@ -463,6 +565,29 @@ Proof. intros. lra. Qed. +Lemma Cmod_eq_0_iff x : Cmod x = 0 <-> x = 0. +Proof. + split; [apply Cmod_eq_0|intros ->; apply Cmod_0]. +Qed. + +Lemma Cmod_eq_C0_iff x : @eq C (Cmod x) 0 <-> x = 0. +Proof. + split; [intros H; apply Cmod_eq_0 + |intros ->; now rewrite Cmod_0]. + apply (f_equal fst) in H. + apply H. +Qed. + +Lemma Cmod_real_abs z : snd z = 0 -> Cmod z = Rabs (fst z). +Proof. + intros Hreal. + unfold Cmod. + rewrite Hreal. + rewrite Rpow_0_l, Rplus_0_r by easy. + rewrite <- pow2_abs. + now rewrite sqrt_pow2 by (apply Rabs_pos). +Qed. + Lemma Cmult_neq_0 (z1 z2 : C) : z1 <> 0 -> z2 <> 0 -> z1 * z2 <> 0. Proof. intros Hz1 Hz2 Hz. @@ -482,6 +607,13 @@ Proof. intros. easy. Qed. +Lemma Cmult_integral_iff (a b : C) : + a * b = 0 <-> (a = 0 \/ b = 0). +Proof. + split; [apply Cmult_integral|]. + intros [-> | ->]; lca. +Qed. + Lemma Cminus_eq_contra : forall r1 r2 : C, r1 <> r2 -> r1 - r2 <> 0. Proof. intros ; contradict H ; apply injective_projections ; @@ -554,6 +686,55 @@ Proof. intros. apply C0_fst_neq. easy. Qed. (** Other useful facts *) +Lemma Cmult_if_l (b : bool) (c d : C) : + (if b then c else 0%R) * d = + if b then c * d else 0%R. +Proof. + destruct b; lca. +Qed. + +Lemma Cmult_if_r (b : bool) (c d : C) : + d * (if b then c else 0%R) = + if b then d * c else 0%R. +Proof. + destruct b; lca. +Qed. + +Lemma Cmult_if_andb (b c : bool) (x y : C) : + (if b then x else 0%R) * (if c then y else 0%R) = + if b && c then x * y else 0%R. +Proof. + destruct b,c; lca. +Qed. + +Lemma Cmult_if_1_l (b : bool) (d : C) : + (if b then C1 else 0%R) * d = + if b then d else 0%R. +Proof. + destruct b; lca. +Qed. + +Lemma Cmult_if_1_r (b : bool) (d : C) : + d * (if b then C1 else 0%R) = + if b then d else 0%R. +Proof. + destruct b; lca. +Qed. + +Lemma Cmult_if_if_1_l (b c : bool) (x : C) : + (if b then C1 else 0%R) * (if c then x else 0%R) = + if b && c then x else 0%R. +Proof. + destruct b; lca. +Qed. + +Lemma Cmult_if_if_1_r (b c : bool) (x : C) : + (if b then x else 0%R) * (if c then C1 else 0%R) = + if b && c then x else 0%R. +Proof. + destruct b,c; lca. +Qed. + Lemma Copp_neq_0_compat: forall c : C, c <> 0 -> (- c)%C <> 0. Proof. intros c H. @@ -572,6 +753,18 @@ Proof. apply C0_fst_neq. apply R1_neq_R0. Qed. +Lemma C1_nonzero : C1 <> 0. +Proof. + apply RtoC_neq. + lra. +Qed. + +Lemma C2_nonzero : C2 <> 0. +Proof. + apply RtoC_neq. + lra. +Qed. + Lemma Cconj_neq_0 : forall c : C, c <> 0 -> c^* <> 0. Proof. intros. @@ -597,6 +790,49 @@ Proof. intros. apply C1_neq_C0; easy. Qed. +Lemma Cinv_eq_0_iff (a : C) : / a = C0 <-> a = 0. +Proof. + split. + - destruct (Ceq_dec a C0) as [? | H%nonzero_div_nonzero]; easy. + - intros ->. + lca. +Qed. + +Lemma Cdiv_integral_iff (a b : C) : + a / b = C0 <-> (a = C0 \/ b = C0). +Proof. + unfold Cdiv. + rewrite Cmult_integral_iff, Cinv_eq_0_iff. + reflexivity. +Qed. + +Lemma Cdiv_integral (a b : C) : + a / b = C0 -> (a = C0 \/ b = C0). +Proof. + rewrite Cdiv_integral_iff. + easy. +Qed. + +Lemma Cdiv_integral_dec (a b : C) : + a / b = C0 -> ({a = C0} + {b = C0}). +Proof. + intros H%Cdiv_integral. + destruct (Ceq_dec a 0); [now left |]. + destruct (Ceq_dec b 0); [now right |]. + exfalso. + destruct H; easy. +Defined. + +Lemma Cdiv_nonzero (c d : C) : c <> 0%R -> d <> 0%R -> + c / d <> 0%R. +Proof. + intros Hc Hd Hf; apply Hc. + apply (f_equal (Cmult d)) in Hf. + rewrite Cdiv_mult_l' in Hf; [|easy]. + rewrite Hf. + lca. +Qed. + Lemma div_real : forall (c : C), snd c = 0 -> snd (/ c) = 0. Proof. intros. @@ -624,9 +860,13 @@ Proof. intros. rewrite <- H' in H2. easy. Qed. -Lemma Cinv_mult_distr : forall c1 c2 : C, c1 <> 0 -> c2 <> 0 -> / (c1 * c2) = / c1 * / c2. +Lemma Cinv_mult_distr : forall c1 c2 : C, / (c1 * c2) = / c1 * / c2. Proof. intros. + destruct (Ceq_dec c1 0) as [?|H]; + [subst; now rewrite Cmult_0_l, !Cinv_0, Cmult_0_l|]. + destruct (Ceq_dec c2 0) as [?|H0]; + [subst; now rewrite Cmult_0_r, !Cinv_0, Cmult_0_r|]. apply c_proj_eq. - simpl. repeat rewrite Rmult_1_r. @@ -664,12 +904,15 @@ Proof. Qed. -Lemma Cinv_inv : forall c : C, c <> C0 -> / / c = c. -Proof. intros. - apply (Cmult_cancel_l (/ c)). - apply nonzero_div_nonzero; auto. - rewrite Cinv_l, Cinv_r; auto. - apply nonzero_div_nonzero; auto. +Lemma Cinv_inv : forall c : C, / / c = c. +Proof. + intros. + destruct (Ceq_dec c C0). + - subst. now rewrite 2!Cinv_0. + - apply (Cmult_cancel_l (/ c)). + apply nonzero_div_nonzero; auto. + rewrite Cinv_l, Cinv_r; auto. + apply nonzero_div_nonzero; auto. Qed. Lemma Cconj_eq_implies_real : forall c : C, c = Cconj c -> snd c = 0%R. @@ -686,8 +929,58 @@ Qed. (** * some C big_sum specific lemmas *) +Lemma times_n_C : forall n (c : C), + times_n c n = c * INR n. +Proof. + intros n c. + induction n; [lca|]. + cbn [times_n]. + rewrite S_INR, IHn. + lca. +Qed. + + Local Open Scope nat_scope. +Lemma Rsum_big_sum : forall n (f : nat -> R), + fst (big_sum (fun i => RtoC (f i)) n) = big_sum f n. +Proof. + intros. induction n. + - easy. + - simpl. rewrite IHn. + easy. +Qed. + +Lemma Re_big_sum (n : nat) (f : nat -> C) : + fst (big_sum (fun i => f i) n) = big_sum (fun i => fst (f i)) n. +Proof. + induction n; [easy|]. + simpl; f_equal; easy. +Qed. + +Lemma Im_big_sum (n : nat) (f : nat -> C) : + snd (big_sum (fun i => f i) n) = big_sum (fun i => snd (f i)) n. +Proof. + induction n; [easy|]. + simpl; f_equal; easy. +Qed. + +Lemma big_sum_real n (f : nat -> C) + (Hf : forall i, (i < n)%nat -> snd (f i) = 0%R) : + big_sum f n = big_sum (fun i => fst (f i)) n. +Proof. + rewrite (big_sum_eq_bounded _ (fun i => RtoC (fst (f i)))). + - apply c_proj_eq. + + rewrite Rsum_big_sum; easy. + + rewrite Im_big_sum. + simpl. + clear Hf. + induction n; [easy|]. + simpl; lra. + - intros i Hi. + apply c_proj_eq; [easy|]. + apply Hf; easy. +Qed. (* TODO: these should all probably have better names *) Lemma big_sum_rearrange : forall (n : nat) (f g : nat -> nat -> C), @@ -791,14 +1084,6 @@ Proof. intros. induction n. simpl; lra. Qed. -Lemma Rsum_big_sum : forall n (f : nat -> R), - fst (big_sum (fun i => RtoC (f i)) n) = big_sum f n. -Proof. - intros. induction n. - - easy. - - simpl. rewrite IHn. - easy. -Qed. Lemma big_sum_Cmod_0_all_0 : forall (f : nat -> C) (n : nat), big_sum (fun i => Cmod (f i)) n = 0 -> @@ -826,8 +1111,42 @@ Proof. induction n as [| n']. easy. Qed. +Lemma Cmod_real_nonneg_sum_ge_any n (f : nat -> C) k (Hk : (k < n)%nat) + (Hf_re : forall i, (i < n)%nat -> snd (f i) = 0) + (Hf_nonneg : forall i, (i < n)%nat -> 0 <= fst (f i)): + Cmod (f k) <= Cmod (big_sum (fun i => f i) n). +Proof. + rewrite big_sum_real by easy. + rewrite 2!Cmod_real; try apply Hf_re; + try apply Rle_ge; + try apply (Hf_nonneg k Hk); try easy. + - simpl. + apply (Rsum_nonneg_ge_any n (fun i => fst (f i))); easy. + - apply Rsum_ge_0_on. + easy. +Qed. +Lemma big_sum_if_eq_C (f : nat -> C) n k : + big_sum (fun x => if (x =? k)%nat then f x else 0%R) n = + (if (k C) n k : + big_sum (fun x => if (k =? x)%nat then f x else 0%R) n = + (if (k c = false) -> (c = true -> b = false) -> + ((if b then v else 0%R) + (if c then v else 0%R) = + if b || c then v else 0%R)%C. +Proof. + destruct b, c; simpl; intros; lca. +Qed. (** * Lemmas about Cpow *) @@ -844,6 +1163,15 @@ Proof. reflexivity. Qed. +Lemma Re_Cpow (c : C) (Hc : snd c = 0) n : + fst (Cpow c n) = pow (fst c) n. +Proof. + induction n; [easy|]. + simpl. + rewrite Hc, IHn. + lra. +Qed. + Lemma Cpow_nonzero_real : forall (r : R) (n : nat), (r <> 0 -> r ^ n <> C0)%C. Proof. intros. @@ -862,6 +1190,15 @@ Proof. apply Cmult_neq_0; easy. Qed. +Lemma Cpow_0_l : forall n, n <> O -> C0 ^ n = C0. +Proof. + intros n. + destruct n; [easy|]. + simpl. + rewrite Cmult_0_l. + reflexivity. +Qed. + Lemma Cpow_add : forall (c : C) (n m : nat), (c ^ (n + m) = c^n * c^m)%C. Proof. intros. induction n. simpl. lca. @@ -881,21 +1218,9 @@ Proof. induction n. - lca. - simpl. - rewrite IHn; try assumption. - rewrite Cinv_mult_distr. - + reflexivity. - + assert (c ^ 1 <> 0). - { - apply H. - apply Nat.le_pred_le_succ. - simpl. - apply Nat.le_0_l. - } - simpl in H0. - rewrite Cmult_1_r in H0. - assumption. - + apply H. - apply Nat.le_succ_diag_r. + rewrite IHn. + + rewrite Cinv_mult_distr. + reflexivity. + intros. apply H. apply le_S. @@ -1021,6 +1346,14 @@ Proof. reflexivity. Qed. +Lemma Cmult_conj_nonneg (c : C) : + 0 <= fst (c ^* * c)%C. +Proof. + rewrite <- Cmod_sqr, RtoC_pow. + apply pow2_ge_0. +Qed. + + Lemma Cconj_simplify : forall (c1 c2 : C), c1^* = c2^* -> c1 = c2. Proof. intros. assert (H1 : c1 ^* ^* = c2 ^* ^*). { rewrite H; easy. } @@ -1116,12 +1449,102 @@ Proof. intros. easy. Qed. +Definition Csqrt (z : C) : C := + match z with + | (a, b) => sqrt ((Cmod z + a) / 2) + Ci * (b / Rabs b) * sqrt((Cmod z - a) / 2) + end. + +(* TODO: Remove; this is in Reals past coq 8.16 *) +Lemma Req_dec_T : forall r1 r2:R, {r1 = r2} + {r1 <> r2}. +Proof. + intros r1 r2; destruct (total_order_T r1 r2) as [[H | ] | H]. + - now right; intros ->; apply (Rlt_irrefl r2). + - now left. + - now right; intros ->; apply (Rlt_irrefl r2 H). +Qed. + +Definition Csqrt_alt z := + if Req_dec_T (snd z) 0 then + if Rcase_abs (fst z) + then Ci * (√ (- fst z)) + else √ (fst z) + else + √ Cmod z * (z + Cmod z) / (Cmod (z + Cmod z)). + +Lemma Csqrt_Csqrt_alt (z : C) : + Csqrt_alt z * Csqrt_alt z = z. +Proof. + unfold Csqrt_alt. + destruct z as [a b]. + cbn [fst snd]. + destruct (Req_dec_T b 0). + - destruct (Rcase_abs a). + + field_simplify. + rewrite Ci2. + rewrite <- Cmult_assoc. + rewrite <- RtoC_mult. + rewrite sqrt_sqrt by lra. + lca. + + rewrite <- RtoC_mult, sqrt_sqrt by lra. + lca. + - assert (Hnz: RtoC (Cmod ((a,b) + Cmod (a,b))) <> 0). 1: { + rewrite Cmod_eq_C0_iff. + intros Heq. + apply (f_equal snd) in Heq. + simpl in Heq. + lra. + } + field_simplify. + 2: { easy. } + rewrite <- !RtoC_mult, sqrt_sqrt by apply Cmod_ge_0. + field_simplify. + 2: { intros Heq; apply (f_equal fst) in Heq; simpl in Heq; apply Hnz. + apply Rmult_integral in Heq. + now destruct Heq as [-> | ->]. + } + rewrite <- 2!Cmult_assoc, <- RtoC_mult, sqrt_sqrt by apply Cmod_ge_0. + field_simplify. + 2: { intros Heq; apply (f_equal fst) in Heq; simpl in Heq; apply Hnz. + apply Rmult_integral in Heq. + now destruct Heq as [-> | ->]. + } + generalize Hnz; intros Hnz'. + rewrite Cmod_eq_C0_iff in Hnz'. + unfold Cmod in *. + pose proof (pow2_ge_0 a). + pose proof (pow2_ge_0 b). + cbn [fst snd] in *. + rewrite <- RtoC_mult, !sqrt_sqrt in * by lra. + cbn [fst snd RtoC Cplus] in *. + rewrite sqrt_sqrt by (pose proof (pow2_ge_0 (a+√(a^2+b^2))); + pose proof (pow2_ge_0 (b+0)); lra). + + field_simplify_eq. + 2: {intros Hrw. apply (f_equal fst) in Hrw. cbn [fst RtoC] in Hrw. + rewrite Hrw in *. now rewrite sqrt_0 in *. } + apply c_proj_eq. + + simpl in *. + field_simplify. + simpl. + rewrite !Rmult_1_r. + rewrite sqrt_sqrt by lra. + lra. + + simpl. + field_simplify. + simpl. + rewrite !Rmult_1_r. + rewrite sqrt_sqrt by lra. + lra. +Qed. (** * Complex exponentiation **) (** Compute e^(iθ) *) Definition Cexp (θ : R) : C := (cos θ, sin θ). +Lemma Cexp_spec : forall α, Cexp α = cos α + Ci * sin α. +Proof. intros; lca. Qed. + Lemma Cexp_0 : Cexp 0 = 1. Proof. unfold Cexp. autorewrite with trig_db; easy. Qed. @@ -1154,6 +1577,16 @@ Proof. field. Qed. +Lemma Cexp_minus : forall θ, + Cexp θ + Cexp (-θ) = 2 * cos θ. +Proof. + intros. + unfold Cexp. + rewrite cos_neg. + rewrite sin_neg. + lca. +Qed. + Lemma Cexp_plus_PI : forall x, Cexp (x + PI) = (- (Cexp x))%C. Proof. @@ -1400,10 +1833,11 @@ Proof. Qed. Lemma Cexp_mod_2PI_scaled : forall (k sc : Z), - (sc <> 0)%Z -> Cexp (IZR k * PI / IZR sc) = Cexp (IZR (k mod (2 * sc)) * PI / IZR sc). Proof. - intros k sc H. + intros k sc. + destruct (Z.eq_dec sc 0) as [?|H]; + [subst; simpl; now rewrite 2!Rdiv_0_r|]. rewrite (Z.div_mod k (2 * sc)) at 1 by lia. repeat rewrite plus_IZR. unfold Rdiv. @@ -1423,6 +1857,134 @@ Qed. Cexp_1PI4 Cexp_2PI4 Cexp_3PI4 Cexp_4PI4 Cexp_5PI4 Cexp_6PI4 Cexp_7PI4 Cexp_8PI4 Cexp_add Cexp_neg Cexp_plus_PI Cexp_minus_PI : Cexp_db. +Lemma INR_pi_exp : forall (r : nat), + Cexp (INR r * PI) = 1 \/ Cexp (INR r * PI) = -1. +Proof. + intros. + dependent induction r. + - simpl. + rewrite Rmult_0_l. + left. + apply Cexp_0. + - rewrite S_O_plus_INR. + rewrite Rmult_plus_distr_r. + rewrite Rmult_1_l. + rewrite Rplus_comm. + rewrite Cexp_plus_PI. + destruct IHr. + + rewrite H; right; lca. + + rewrite H; left; lca. +Qed. + +Lemma Cexp_2_PI : forall a, Cexp (INR a * 2 * PI) = 1. +Proof. + intros. + induction a. + - simpl. + rewrite 2 Rmult_0_l. + rewrite Cexp_0. + easy. + - rewrite S_INR. + rewrite 2 Rmult_plus_distr_r. + rewrite Rmult_1_l. + rewrite double. + rewrite <- Rplus_assoc. + rewrite 2 Cexp_plus_PI. + rewrite IHa. + lca. +Qed. + +Lemma sin_sin_PI8 : + sin (PI / 8) * sin (PI / 8) = + cos (PI / 8) * cos (PI / 8) - cos (PI / 4). +Proof. + replace (PI / 4)%R with (2 * (PI / 8))%R by lra. + rewrite cos_2a. + lca. +Qed. + + +Definition CexpC (c : C) := + exp (Re c) * Cexp (Im c). + +Lemma CexpC_def (c : C) : + CexpC c = exp (Re c) * Cexp (Im c). +Proof. reflexivity. Qed. + +Lemma CexpC_add (c d : C) : + CexpC (c + d) = CexpC c * CexpC d. +Proof. + unfold CexpC, Im, Re. + cbn. + rewrite exp_plus, Cexp_add. + lca. +Qed. + +Lemma CexpC_neg (c : C) : + CexpC (-c) = / CexpC c. +Proof. + unfold CexpC, Im, Re. + cbn. + pose proof (exp_pos (fst c)). + rewrite exp_Ropp, Cexp_neg, RtoC_inv, Cinv_mult_distr. + reflexivity. +Qed. + +Lemma CexpC_minus (c d : C) : + CexpC (c - d) = CexpC c / CexpC d. +Proof. + unfold Cminus. + rewrite CexpC_add, CexpC_neg. + reflexivity. +Qed. + +Lemma CexpC_zero : CexpC 0 = 1. +Proof. + unfold CexpC. + cbn. + rewrite exp_0, Cexp_0. + lca. +Qed. + +Lemma Cmod_CexpC c : Cmod (CexpC c) = exp (Re c). +Proof. + unfold CexpC. + rewrite Cmod_mult, Cmod_Cexp, Rmult_1_r. + apply Cmod_real; [cbn | reflexivity]. + pose proof (exp_pos (Re c)). + lra. +Qed. + +Lemma Cexp_CexpC (r : R) : + Cexp r = CexpC (0, r). +Proof. + unfold CexpC. + cbn. + rewrite exp_0. + lca. +Qed. + +Lemma RtoC_exp (x : R) : + RtoC (exp x) = CexpC x. +Proof. + apply c_proj_eq; simpl; + autorewrite with trig_db; lra. +Qed. + +Lemma Cmod_1_plus_Cexp (r : R) : + Cmod (1 + Cexp r) = √ (2 + 2 * cos r)%R. +Proof. + unfold Cmod. + f_equal. + simpl. + pose proof sin2_cos2 r as H. + rewrite 2!Rsqr_pow2 in H. + field_simplify. + rewrite (Rplus_comm _ (_ ^ 2)), <- Rplus_assoc. + rewrite H. + lra. +Qed. + Opaque C. @@ -1473,20 +2035,26 @@ Ltac nonzero := #[global] Hint Rewrite Csqrt_sqrt using Psatz.lra : C_db. #[global] Hint Rewrite Cinv_l Cinv_r using nonzero : C_db. (* Previously in the other direction *) -#[global] Hint Rewrite Cinv_mult_distr using nonzero : C_db. +#[global] Hint Rewrite Cinv_mult_distr : C_db. (* Light rewriting db *) #[global] Hint Rewrite Cplus_0_l Cplus_0_r Cmult_0_l Cmult_0_r Copp_0 Cconj_R Cmult_1_l Cmult_1_r : C_db_light. (* Distributing db *) -#[global] Hint Rewrite Cmult_plus_distr_l Cmult_plus_distr_r Copp_plus_distr Copp_mult_distr_l - Copp_involutive : Cdist_db. +#[global] Hint Rewrite Cmult_plus_distr_l Cmult_plus_distr_r + Copp_plus_distr Copp_mult_distr_l Copp_involutive : Cdist_db. #[global] Hint Rewrite RtoC_opp RtoC_mult RtoC_minus RtoC_plus : RtoC_db. #[global] Hint Rewrite RtoC_inv using nonzero : RtoC_db. #[global] Hint Rewrite RtoC_pow : RtoC_db. +Lemma Copp_Ci : / Ci = - Ci. +Proof. field_simplify_eq; lca + nonzero. Qed. + +#[global] Hint Rewrite Copp_Ci : C_db. + + Ltac Csimpl := repeat match goal with | _ => rewrite Cmult_0_l @@ -1498,8 +2066,20 @@ Ltac Csimpl := | _ => rewrite Cconj_R end. +Ltac Csimpl_in H := + repeat + match goal with + | _ => rewrite Cmult_0_l in H + | _ => rewrite Cmult_0_r in H + | _ => rewrite Cplus_0_l in H + | _ => rewrite Cplus_0_r in H + | _ => rewrite Cmult_1_l in H + | _ => rewrite Cmult_1_r in H + | _ => rewrite Cconj_R in H + end. + Ltac C_field_simplify := repeat field_simplify_eq [ Csqrt2_sqrt Csqrt2_inv Ci2]. -Ltac C_field := C_field_simplify; nonzero; trivial. +Ltac C_field := C_field_simplify; nonzero || trivial; try trivial. Ltac has_term t exp := match exp with diff --git a/Eigenvectors.v b/Eigenvectors.v index 5d38996..851e9f4 100644 --- a/Eigenvectors.v +++ b/Eigenvectors.v @@ -1,13 +1,12 @@ (** This file contains more concepts relevent to quantum computing, as well as some more general linear algebra concepts such as Gram-Schmidt and eigenvectors/eigenvalues. *) +Require Import Permutations. Require Import List. Require Export Complex. Require Export CauchySchwarz. Require Export Quantum. Require Import FTA. -Require Import Permutations. - (****************************) (** * Proving some indentities *) @@ -133,6 +132,17 @@ Proof. intros n U. split. * rewrite H2; try nia; easy. Qed. +Lemma unitary_abs_le_1 {n} {A : Matrix n n} (HA: WF_Unitary A) : + forall i j, + (Cmod (A i j) <= 1)%R. +Proof. + intros i j. + bdestruct (i WF_Matrix B -> WF_Unitary X -> (forall i, A × (get_col X i) = B × (get_col X i)) -> A = B. @@ -170,6 +180,21 @@ Proof. intros. apply H. Qed. +Lemma unit_det_Cmod_1 : forall {n} (U : Square n), + WF_Unitary U -> Cmod (Determinant U) = 1%R. +Proof. + intros n U [HWF Hinv]. + apply (f_equal (fun A => √ (Cmod (Determinant A)))) in Hinv. + revert Hinv. + rewrite <- Determinant_multiplicative, <- Determinant_adjoint. + rewrite Cmod_mult, Cmod_Cconj. + let a := constr:(Cmod (Determinant U)) in + replace (a * a)%R with (a ^ 2)%R by lra. + rewrite sqrt_pow2 by apply Cmod_ge_0. + intros ->. + rewrite Det_I, Cmod_1. + apply sqrt_1. +Qed. (***********************************************************************************) (** * We now define diagonal matrices and diagonizable matrices, proving basic lemmas *) @@ -1970,7 +1995,7 @@ Proof. intros. intros. bdestruct (i hermitian A -> WF_Matrix v -> v <> Zero -> @@ -2422,7 +2443,10 @@ Proof. intros. rewrite <- inner_product_scale_l, <- inner_product_scale_r, <- H3, inner_product_adjoint_switch. rewrite H0; easy. -Qed. +Qed. + +#[deprecated(note="Use hermitian_real_eigenvalues instead")] +Notation hermitiam_real_eigenvalues := hermitian_real_eigenvalues. Lemma unitary_eigenvalues_norm_1 : forall {n} (U : Square n) (v : Vector n) (λ : C), WF_Unitary U -> WF_Matrix v -> @@ -2677,22 +2701,22 @@ Local Close Scope nat_scope. Lemma EigenXp : Eigenpair σx (∣+⟩, C1). -Proof. unfold Eigenpair. solve_matrix. Qed. +Proof. unfold Eigenpair. simpl. lma'. Qed. Lemma EigenXm : Eigenpair σx (∣-⟩, -C1). -Proof. unfold Eigenpair. solve_matrix. Qed. +Proof. unfold Eigenpair. simpl. lma'. Qed. Lemma EigenYp : Eigenpair σy (∣R⟩, C1). -Proof. unfold Eigenpair. solve_matrix. Qed. +Proof. unfold Eigenpair. simpl. lma'. Qed. Lemma EigenYm : Eigenpair σy (∣L⟩, -C1). -Proof. unfold Eigenpair. solve_matrix. Qed. +Proof. unfold Eigenpair. simpl. lma'. Qed. Lemma EigenZp : Eigenpair σz (∣0⟩, C1). -Proof. unfold Eigenpair. solve_matrix. Qed. +Proof. unfold Eigenpair. simpl. lma'. Qed. Lemma EigenZm : Eigenpair σz (∣1⟩, -C1). -Proof. unfold Eigenpair. solve_matrix. Qed. +Proof. unfold Eigenpair. simpl. lma'. Qed. Lemma EigenXXB : Eigenpair (σx ⊗ σx) (∣Φ+⟩, C1). Proof. unfold Eigenpair. lma'. Qed. diff --git a/Kronecker.v b/Kronecker.v new file mode 100644 index 0000000..91083fb --- /dev/null +++ b/Kronecker.v @@ -0,0 +1,1890 @@ +Require Export Matrix. +Require Import RowColOps. +Import Complex. +Require Import Modulus. + +Local Open Scope nat_scope. +Local Open Scope C_scope. + +Lemma kron_I_r {n m p} (A : Matrix n m) : + mat_equiv (A ⊗ I p) + (fun i j => if i mod p =? j mod p then A (i / p)%nat (j / p)%nat else C0). +Proof. + intros i j Hi Hj. + unfold kron, I. + pose proof (Nat.mod_upper_bound i p ltac:(lia)). + bdestructΩ'; lca. +Qed. + +Lemma kron_I_l {n m p} (A : Matrix n m) : + mat_equiv (I p ⊗ A) + (fun i j => if i / n =? j / m then A (i mod n) (j mod m) else C0). +Proof. + intros i j Hi Hj. + unfold kron, I. + rewrite Nat.mul_comm in Hi. + pose proof (Nat.Div0.div_lt_upper_bound _ _ _ Hi). + bdestructΩ'; lca. +Qed. + +Lemma kron_I_l_eq {n m} (A : Matrix n m) p : WF_Matrix A -> + I p ⊗ A = + (fun i j : nat => + if ((i / n =? j / m) && (i + A ⊗ I p = + (fun i j : nat => + if (i mod p =? j mod p) && (i + (* have blocks H_ij, p by q of them, and each is q by p *) + let i := (s / q)%nat in let j := (t / p)%nat in + let k := (s mod q)%nat in let l := (t mod p) in + if (i =? l) && (j =? k) then C1 else C0 +). + +Lemma WF_kron_comm p q : WF_Matrix (kron_comm p q). +Proof. unfold kron_comm; + rewrite Nat.mul_comm; + trivial with wf_db. Qed. +#[export] Hint Resolve WF_kron_comm : wf_db. + + +Lemma kron_comm_transpose_mat_equiv : forall p q, + (kron_comm p q) ⊤ ≡ kron_comm q p. +Proof. + intros p q. + intros i j Hi Hj. + unfold kron_comm, transpose, make_WF. + rewrite andb_comm, Nat.mul_comm. + rewrite (andb_comm (_ =? _)). + easy. +Qed. + +Lemma kron_comm_transpose : forall p q, + (kron_comm p q) ⊤ = kron_comm q p. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + apply kron_comm_transpose_mat_equiv. +Qed. + +Lemma kron_comm_adjoint : forall p q, + (kron_comm p q) † = kron_comm q p. +Proof. + intros p q. + apply mat_equiv_eq; [auto with wf_db..|]. + unfold adjoint. + intros i j Hi Hj. + change (kron_comm p q j i) with ((kron_comm p q) ⊤ i j). + rewrite kron_comm_transpose_mat_equiv by easy. + unfold kron_comm, make_WF. + rewrite !(@if_dist C C). + bdestructΩ'; lca. +Qed. + +Lemma kron_comm_1_r_mat_equiv : forall p, + (kron_comm p 1) ≡ Matrix.I p. +Proof. + intros p. + intros s t Hs Ht. + unfold kron_comm. + unfold make_WF. + unfold Matrix.I. + rewrite Nat.mul_1_r, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Lemma kron_comm_1_r : forall p, + (kron_comm p 1) = Matrix.I p. +Proof. + intros p. + apply mat_equiv_eq; [|rewrite Nat.mul_1_l, Nat.mul_1_r|]; auto with wf_db. + apply kron_comm_1_r_mat_equiv. +Qed. + +Lemma kron_comm_1_l_mat_equiv : forall p, + (kron_comm 1 p) ≡ Matrix.I p. +Proof. + intros p. + intros s t Hs Ht. + unfold kron_comm. + unfold make_WF. + unfold Matrix.I. + rewrite Nat.mul_1_l, Nat.div_1_r, Nat.mod_1_r, Nat.div_small, Nat.mod_small by lia. + bdestructΩ'. +Qed. + +Lemma kron_comm_1_l : forall p, + (kron_comm 1 p) = Matrix.I p. +Proof. + intros p. + apply mat_equiv_eq; [|rewrite Nat.mul_1_l, Nat.mul_1_r|]; auto with wf_db. + apply kron_comm_1_l_mat_equiv. +Qed. + +Definition mx_to_vec {n m} (A : Matrix n m) : Vector (m * n) := + make_WF (fun i j => A (i mod n)%nat (i / n)%nat + (* Note: goes columnwise. Rowwise would be: + make_WF (fun i j => A (i / m)%nat (i mod n)%nat + *) +). + +Lemma WF_mx_to_vec {n m} (A : Matrix n m) : WF_Matrix (mx_to_vec A). +Proof. unfold mx_to_vec; auto with wf_db. Qed. +#[export] Hint Resolve WF_mx_to_vec : wf_db. + +(* Compute vec_to_list (mx_to_vec (Matrix.I 2)). *) +From Coq Require Import ZArith. +Ltac Zify.zify_post_hook ::= PreOmega.Z.div_mod_to_equations. + +Lemma kron_comm_mx_to_vec_helper : forall i p q, (i < p * q)%nat -> + (p * (i mod q) + i / q < p * q)%nat. +Proof. + intros i p q Hi. + show_moddy_lt. +Qed. + +Lemma mx_to_vec_additive_mat_equiv {n m} (A B : Matrix n m) : + mx_to_vec (A .+ B) ≡ mx_to_vec A .+ mx_to_vec B. +Proof. + intros i j Hi Hj. + replace j with O by lia; clear dependent j. + unfold mx_to_vec, make_WF, Mplus. + bdestructΩ'. +Qed. + +Lemma mx_to_vec_additive {n m} (A B : Matrix n m) : + mx_to_vec (A .+ B) = mx_to_vec A .+ mx_to_vec B. +Proof. + apply mat_equiv_eq; auto with wf_db. + apply mx_to_vec_additive_mat_equiv. +Qed. + +Lemma if_mult_dist_r (b : bool) (z : C) : + (if b then C1 else C0) * z = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_dist_l (b : bool) (z : C) : + z * (if b then C1 else C0) = + if b then z else C0. +Proof. + destruct b; lca. +Qed. + +Lemma if_mult_and (b c : bool) : + (if b then C1 else C0) * (if c then C1 else C0) = + if (b && c) then C1 else C0. +Proof. + destruct b; destruct c; lca. +Qed. + +Lemma kron_comm_mx_to_vec_mat_equiv : forall p q (A : Matrix p q), + kron_comm p q × mx_to_vec A ≡ mx_to_vec (A ⊤). +Proof. + intros p q A. + intros i j Hi Hj. + replace j with O by lia; clear dependent j. + unfold transpose, mx_to_vec, kron_comm, make_WF, Mmult. + rewrite (Nat.mul_comm q p). + replace_bool_lia (i . + destruct p; [lia|]. + destruct q; [lia|]. + split. + + rewrite Nat.add_comm, Nat.mul_comm. + rewrite Nat.Div0.mod_add by easy. + rewrite Nat.mod_small; [lia|]. + show_moddy_lt. + + rewrite Nat.mul_comm, Nat.div_add_l by easy. + rewrite Nat.div_small; [lia|]. + show_moddy_lt. + - intros [Hmodp Hdivp]. + rewrite (Nat.div_mod_eq k p). + lia. + } + apply big_sum_unique. + exists (p * (i mod q) + i / q)%nat; repeat split; + [apply kron_comm_mx_to_vec_helper; easy | rewrite Nat.eqb_refl | intros; + bdestructΩ'simp]. + destruct p; [lia|]; + destruct q; [lia|]. + f_equal. + - rewrite Nat.add_comm, Nat.mul_comm, Nat.Div0.mod_add, Nat.mod_small; try easy. + show_moddy_lt. + - rewrite Nat.mul_comm, Nat.div_add_l by easy. + rewrite Nat.div_small; [lia|]. + show_moddy_lt. +Qed. + +Lemma kron_comm_mx_to_vec : forall p q (A : Matrix p q), + kron_comm p q × mx_to_vec A = mx_to_vec (A ⊤). +Proof. + intros p q A. + apply mat_equiv_eq; auto with wf_db. + apply kron_comm_mx_to_vec_mat_equiv. +Qed. + +Lemma kron_comm_ei_kron_ei_sum_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (G:=Matrix (p*q) (q*p)) + (fun i => big_sum (fun j => + (@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤)) + q) p. +Proof. + intros p q. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros l Hl. + unfold Mmult, kron, transpose, e_i. + erewrite big_sum_eq_bounded. + 2: { + intros m Hm. + (* replace m with O by lia. *) + rewrite Nat.div_1_r, Nat.mod_1_r. + replace_bool_lia (m =? 0) true; rewrite 4!andb_true_r. + rewrite 3!if_mult_and. + match goal with + |- context[if ?b then _ else _] => + replace b with ((i =? k * q + l) && (j =? l * p + k)) + end. + 1: reflexivity. (* set our new function *) + clear dependent m. + rewrite eq_iff_eq_true, 8!andb_true_iff, + 6!Nat.eqb_eq, 4!Nat.ltb_lt. + split. + - intros [Hieq Hjeq]. + subst i j. + rewrite 2!Nat.div_add_l, Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.Div0.mod_add, Nat.mod_small, + Nat.div_small, Nat.add_0_r by lia. + rewrite Nat.add_comm, Nat.Div0.mod_add, Nat.mod_small by lia. + easy. + - intros [[[] []] [[] []]]. + split. + + rewrite (Nat.div_mod_eq i q) by lia; lia. + + rewrite (Nat.div_mod_eq j p) by lia; lia. + } + simpl; rewrite Cplus_0_l. + reflexivity. + } + apply big_sum_unique. + exists (i mod q). + split; [|split]. + - apply Nat.mod_upper_bound; lia. + - reflexivity. + - intros l Hl Hnmod. + bdestructΩ'simp. + exfalso; apply Hnmod. + rewrite Nat.add_comm, Nat.Div0.mod_add, Nat.mod_small by lia; lia. + } + symmetry. + apply big_sum_unique. + exists (j mod p). + repeat split. + - apply Nat.mod_upper_bound; lia. + - unfold kron_comm, make_WF. + replace_bool_lia (i big_sum (fun j => + (@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤)) + q) p. +Proof. + intros p q. + apply mat_equiv_eq; auto 10 with wf_db. + apply kron_comm_ei_kron_ei_sum_mat_equiv. +Qed. + +Lemma kron_comm_ei_kron_ei_sum'_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + ((@e_i p i ⊗ @e_i q j) × ((@e_i q j ⊗ @e_i p i) ⊤))) (p*q). +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum, big_sum_double_sum, Nat.mul_comm. + reflexivity. +Qed. + +(* TODO: put somewhere sensible *) +Lemma big_sum_mat_equiv_bounded : forall {o p} (f g : nat -> Matrix o p) (n : nat), + (forall x : nat, (x < n)%nat -> f x ≡ g x) -> big_sum f n ≡ big_sum g n. +Proof. + intros. + induction n. + - easy. + - simpl. + rewrite IHn, H; [easy|lia|auto]. +Qed. + +Lemma kron_comm_Hij_sum_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum (fun i => big_sum (fun j => + @kron p q q p (@e_i p i × ((@e_i q j) ⊤)) + ((@Mmult p 1 q (@e_i p i) (((@e_i q j) ⊤))) ⊤)) q) p. +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum_mat_equiv. + apply big_sum_mat_equiv_bounded; intros i Hi. + apply big_sum_mat_equiv_bounded; intros j Hj. + rewrite kron_transpose, kron_mixed_product. + rewrite Mmult_transpose, transpose_involutive. + easy. +Qed. + +Lemma kron_comm_Hij_sum : forall p q, + kron_comm p q = + big_sum (fun i => big_sum (fun j => + e_i i × (e_i j) ⊤ ⊗ + (e_i i × (e_i j) ⊤) ⊤) q) p. +Proof. + intros p q. + apply mat_equiv_eq; [auto 10 with wf_db.. | ]. + apply kron_comm_Hij_sum_mat_equiv. +Qed. + + +Lemma kron_comm_ei_kron_ei_sum' : forall p q, + kron_comm p q = + big_sum (fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + ((e_i i ⊗ e_i j) × ((e_i j ⊗ e_i i) ⊤))) (p*q). +Proof. + intros p q. + rewrite kron_comm_ei_kron_ei_sum, big_sum_double_sum, Nat.mul_comm. + reflexivity. +Qed. + +Local Notation H := (fun i j => e_i i × (e_i j)⊤). + +Lemma kron_comm_Hij_sum'_mat_equiv : forall p q, + kron_comm p q ≡ + big_sum ( fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + H i j ⊗ (H i j) ⊤) (p*q). +Proof. + intros p q. + rewrite kron_comm_Hij_sum_mat_equiv, big_sum_double_sum, Nat.mul_comm. + easy. +Qed. + +Lemma kron_comm_Hij_sum' : forall p q, + kron_comm p q = + big_sum (fun ij => + let i := (ij / q)%nat in let j := (ij mod q) in + H i j ⊗ (H i j) ⊤) (p*q). +Proof. + intros p q. + rewrite kron_comm_Hij_sum, big_sum_double_sum, Nat.mul_comm. + easy. +Qed. + + +Lemma div_eq_iff : forall a b c, b <> O -> + (a / b)%nat = c <-> (b * c <= a /\ a < b * (S c))%nat. +Proof. + intros a b c Hb. + split. + intros Hadivb. + split; + subst c. + - rewrite (Nat.div_mod_eq a b) at 2; lia. + - now apply Nat.mul_succ_div_gt. + - intros [Hge Hlt]. + symmetry. + apply (Nat.div_unique _ _ _ (a - b*c)); lia. +Qed. + +Lemma div_eqb_iff : forall a b c, b <> O -> + (a / b)%nat =? c = ((b * c <=? a) && (a + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A = (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A = (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A = (fun i j => + if (i ((j/o)%nat=k)) by lia; + rewrite Hrw; clear Hrw. + symmetry. + rewrite div_eq_iff by lia. + lia. + - replace (i / m =? 0) with false. + rewrite andb_false_r; easy. + symmetry. + rewrite Nat.eqb_neq. + rewrite Nat.div_small_iff; lia. +Qed. + +Lemma kron_e_i_transpose_l'_mat_equiv : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (@e_i n k)⊤ ⊗ A ≡ (fun i j => + if (i + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A = (fun i j => + if (j ((i/m)%nat=k)) by lia; + rewrite Hrw; clear Hrw. + symmetry. + rewrite div_eq_iff by lia. + lia. + - replace (j / o =? 0) with false. + rewrite andb_false_r; easy. + symmetry. + rewrite Nat.eqb_neq. + rewrite Nat.div_small_iff; lia. +Qed. + +Lemma kron_e_i_l'_mat_equiv : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (o <> O) -> (m <> O) -> + (@e_i n k) ⊗ A ≡ (fun i j => + if (j + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) = (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0%R). +Proof. + intros k n m o A Hk Ho Hm. + apply functional_extensionality; intros i; + apply functional_extensionality; intros j. + unfold kron, e_i. + rewrite if_mult_dist_l, Nat.div_1_r. + rewrite Nat.mod_1_r, Nat.eqb_refl, andb_true_r. + replace (i mod n + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ≡ (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0%R). +Proof. + intros. + rewrite kron_e_i_r; easy. +Qed. + +Lemma kron_e_i_r_mat_equiv' : forall k n m o (A : Matrix m o), (k < n)%nat -> + A ⊗ (@e_i n k) ≡ (fun i j => + if (i mod n =? k) then A (i / n)%nat j else 0%R). +Proof. + intros. + destruct m; [|destruct o]; + try (intros i j Hi Hj; lia). + rewrite kron_e_i_r; easy. +Qed. + +Lemma kron_e_i_transpose_r : forall k n m o (A : Matrix m o), (k < n)%nat -> + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ⊤ = (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0%R). +Proof. + intros k n m o A Hk Ho Hm. + apply functional_extensionality; intros i; + apply functional_extensionality; intros j. + unfold kron, transpose, e_i. + rewrite if_mult_dist_l, Nat.div_1_r. + rewrite Nat.mod_1_r, Nat.eqb_refl, andb_true_r. + replace (j mod n + (o <> O) -> (m <> O) -> + A ⊗ (@e_i n k) ⊤ ≡ (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0%R). +Proof. + intros. + rewrite kron_e_i_transpose_r; easy. +Qed. + +Lemma kron_e_i_transpose_r_mat_equiv' : forall k n m o (A : Matrix m o), (k < n)%nat -> + A ⊗ (@e_i n k) ⊤ ≡ (fun i j => + if (j mod n =? k) then A i (j / n)%nat else 0%R). +Proof. + intros. + destruct m; [|destruct o]; + try (intros i j Hi Hj; lia). + rewrite kron_e_i_transpose_r; easy. +Qed. + +Lemma ei_kron_I_kron_ei : forall m n k, (k < n)%nat -> m <> O -> + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) = + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n m <> O -> + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) ≡ + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n + (@e_i n k) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n k) ≡ + (fun i j => if (i mod n =? k) && (j / m =? k)%nat + && (i / n =? j - k * m) && (i / n + (@e_i n j) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n j)) n. +Proof. + intros m n. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + rewrite ei_kron_I_kron_ei by lia. + reflexivity. + } + unfold kron_comm, make_WF. + do 2 simplify_bools_lia_one_kernel. + replace (i / n (@e_i n j) ⊤ ⊗ (Matrix.I m) ⊗ (@e_i n j)) n. +Proof. + intros m n. + apply mat_equiv_eq; + [|eapply WF_Matrix_dim_change; [lia..|]|]; + [auto with wf_db..|]. + apply kron_comm_kron_form_sum_mat_equiv; easy. +Qed. + +Lemma kron_comm_kron_form_sum' : forall m n, + kron_comm m n = big_sum (fun i => + (@e_i m i) ⊗ (Matrix.I n) ⊗ (@e_i m i)⊤) m. +Proof. + intros. + rewrite <- (kron_comm_transpose n m). + rewrite (kron_comm_kron_form_sum n m). + replace (n * m)%nat with (1 * n * m)%nat by lia. + replace (m * n)%nat with (m * n * 1)%nat by lia. + rewrite (Nat.mul_1_r (m * n * 1)). + etransitivity; + [apply Msum_transpose|]. + apply big_sum_eq_bounded. + intros k Hk. + restore_dims. + rewrite !kron_transpose. + now rewrite id_transpose_eq, transpose_involutive. +Qed. + +Lemma kron_comm_kron_form_sum'_mat_equiv : forall m n, + kron_comm m n ≡ big_sum (fun i => + (@e_i m i) ⊗ (Matrix.I n) ⊗ (@e_i m i)⊤) m. +Proof. + intros. + rewrite kron_comm_kron_form_sum'; easy. +Qed. + +Lemma e_i_dot_is_component_mat_equiv : forall p k (x : Vector p), + (k < p)%nat -> + (@e_i p k) ⊤ × x ≡ x k O .* Matrix.I 1. +Proof. + intros p k x Hk. + intros i j Hi Hj; + replace i with O by lia; + replace j with O by lia; + clear i Hi; + clear j Hj. + unfold Mmult, transpose, scale, e_i, Matrix.I. + simpl_bools. + rewrite Cmult_1_r. + apply big_sum_unique. + exists k. + split; [easy|]. + bdestructΩ'simp. + rewrite Cmult_1_l. + split; [easy|]. + intros l Hl Hkl. + bdestructΩ'simp. +Qed. + +Lemma e_i_dot_is_component : forall p k (x : Vector p), + (k < p)%nat -> WF_Matrix x -> + (@e_i p k) ⊤ × x = x k O .* Matrix.I 1. +Proof. + intros p k x Hk HWF. + apply mat_equiv_eq; auto with wf_db. + apply e_i_dot_is_component_mat_equiv; easy. +Qed. + +Lemma kron_e_i_e_i : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + @e_i q l ⊗ @e_i p k = @e_i (p*q) (l*p + k). +Proof. + intros p q k l Hk Hl. + apply functional_extensionality; intro i. + apply functional_extensionality; intro j. + unfold kron, e_i. + rewrite Nat.mod_1_r, Nat.div_1_r. + rewrite if_mult_and. + apply f_equal_if; [|easy..]. + rewrite Nat.eqb_refl, andb_true_r. + destruct (j =? 0); [|rewrite 2!andb_false_r; easy]. + rewrite 2!andb_true_r. + rewrite eq_iff_eq_true, 4!andb_true_iff, 3!Nat.eqb_eq, 3!Nat.ltb_lt. + split. + - intros [[] []]. + rewrite (Nat.div_mod_eq i p). + split; nia. + - intros []. + subst i. + rewrite Nat.div_add_l, Nat.div_small, Nat.add_0_r, + Nat.add_comm, Nat.Div0.mod_add, Nat.mod_small by lia. + easy. +Qed. + +Lemma kron_e_i_e_i_mat_equiv : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + @e_i q l ⊗ @e_i p k ≡ @e_i (p*q) (l*p + k). +Proof. + intros p q k l; intros. + rewrite (kron_e_i_e_i p q); easy. +Qed. + +Lemma kron_e_i_e_i_split : forall p q k, (k < p * q)%nat -> + @e_i (p*q) k = @e_i q (k / p) ⊗ @e_i p (k mod p). +Proof. + intros p q k Hk. + rewrite (kron_e_i_e_i p q) by show_moddy_lt. + rewrite (Nat.div_mod_eq k p) at 1. + f_equal; lia. +Qed. + +Lemma kron_eq_sum_mat_equiv : forall p q (x : Vector q) (y : Vector p), + y ⊗ x ≡ big_sum (fun ij => + let i := (ij / q)%nat in let j := ij mod q in + (x j O * y i O) .* (@e_i p i ⊗ @e_i q j)) (p * q). +Proof. + intros p q x y. + erewrite big_sum_eq_bounded. + 2: { + intros ij Hij. + simpl. + rewrite (@kron_e_i_e_i q p) by + (try apply Nat.mod_upper_bound; try apply Nat.Div0.div_lt_upper_bound; lia). + rewrite (Nat.mul_comm (ij / q) q). + rewrite <- (Nat.div_mod_eq ij q). + reflexivity. + } + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + simpl. + rewrite Msum_Csum. + symmetry. + apply big_sum_unique. + exists i. + split; [lia|]. + unfold e_i; split. + - unfold scale, kron; bdestructΩ'simp. + - intros j Hj Hij. + unfold scale, kron; bdestructΩ'simp. +Qed. + +Lemma kron_eq_sum : forall p q (x : Vector q) (y : Vector p), + WF_Matrix x -> WF_Matrix y -> + y ⊗ x = big_sum (fun ij => + let i := (ij / q)%nat in let j := ij mod q in + (x j O * y i O) .* (@e_i p i ⊗ @e_i q j)) (p * q). +Proof. + intros p q x y Hwfx Hwfy. + apply mat_equiv_eq; [| |]; auto with wf_db. + apply kron_eq_sum_mat_equiv. +Qed. + +Lemma kron_comm_commutes_vectors_l_mat_equiv : forall p q (x : Vector q) (y : Vector p), + kron_comm p q × (x ⊗ y) ≡ (y ⊗ x). +Proof. + intros p q x y. + rewrite kron_comm_ei_kron_ei_sum'_mat_equiv, Mmult_Msum_distr_r. + + rewrite (big_sum_mat_equiv_bounded _ + (fun k => x (k mod q) 0 * y (k / q) 0 .* (e_i (k / q) ⊗ e_i (k mod q)))%nat); + [rewrite <- kron_eq_sum_mat_equiv; easy|]. + intros k Hk. + simpl. + rewrite Mmult_assoc. + change 1%nat with (1 * 1)%nat. + restore_dims. + rewrite (kron_transpose' (@e_i q (k mod q)) (@e_i p (k / q))). + rewrite kron_mixed_product. + rewrite 2!(e_i_dot_is_component_mat_equiv) by show_moddy_lt. + rewrite Mscale_kron_dist_l, Mscale_kron_dist_r, Mscale_assoc. + rewrite kron_1_l, Mscale_mult_dist_r, Mmult_1_r by auto with wf_db. + reflexivity. +Qed. + +Lemma kron_comm_commutes_vectors_l : forall p q (x : Vector q) (y : Vector p), + WF_Matrix x -> WF_Matrix y -> + kron_comm p q × (x ⊗ y) = (y ⊗ x). +Proof. + intros p q x y Hwfx Hwfy. + apply mat_equiv_eq; auto with wf_db. + apply kron_comm_commutes_vectors_l_mat_equiv. +Qed. + +(* Lemma kron_basis_vector_basis_vector : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + basis_vector q l ⊗ basis_vector p k = basis_vector (p*q) (l*p + k). +Proof. + intros p q k l Hk Hl. + apply functional_extensionality; intros i. + apply functional_extensionality; intros j. + unfold kron, basis_vector. + rewrite Nat.mod_1_r, Nat.div_1_r, Nat.eqb_refl, andb_true_r, if_mult_and. + pose proof (Nat.div_mod_eq i p). + bdestructΩ'simp. + rewrite Nat.div_add_l, Nat.div_small in * by lia. + lia. +Qed. + +Lemma kron_basis_vector_basis_vector_mat_equiv : forall p q k l, + (k < p)%nat -> (l < q)%nat -> + basis_vector q l ⊗ basis_vector p k ≡ basis_vector (p*q) (l*p + k). +Proof. + intros. + rewrite (kron_basis_vector_basis_vector p q); easy. +Qed. *) + +Lemma kron_extensionality_mat_equiv : forall n m s t (A B : Matrix (n*m) (s*t)), + (forall (x : Vector s) (y :Vector t), + A × (x ⊗ y) ≡ B × (x ⊗ y)) -> + A ≡ B. +Proof. + intros n m s t A B Hext. + apply mat_equiv_of_equiv_on_ei. + intros i Hi. + + pose proof (Nat.Div0.div_lt_upper_bound i t s ltac:(lia)). + pose proof (Nat.mod_upper_bound i s ltac:(lia)). + pose proof (Nat.mod_upper_bound i t ltac:(lia)). + + specialize (Hext (@e_i s (i / t)) (@e_i t (i mod t))). + rewrite (kron_e_i_e_i_mat_equiv t s) in Hext by lia. + rewrite (Nat.mul_comm (i/t) t), <- (Nat.div_mod_eq i t) in Hext. + rewrite (Nat.mul_comm t s) in Hext. easy. +Qed. + +Lemma kron_extensionality : forall n m s t (A B : Matrix (n*m) (s*t)), + WF_Matrix A -> WF_Matrix B -> + (forall (x : Vector s) (y :Vector t), + WF_Matrix x -> WF_Matrix y -> + A × (x ⊗ y) = B × (x ⊗ y)) -> + A = B. +Proof. + intros n m s t A B HwfA HwfB Hext. + apply mat_equiv_eq; [auto with wf_db..|]. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite (Nat.mul_comm s t). + rewrite (kron_e_i_e_i_split t s) by lia. + rewrite mat_equiv_eq_iff by + auto using WF_Matrix_dim_change with wf_db zarith. + restore_dims. + apply Hext; auto with wf_db. +Qed. + +Lemma kron_comm_commutes_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + kron_comm m n × (A ⊗ B) × (kron_comm s t) ≡ (B ⊗ A). +Proof. + intros n s m t A B. + apply kron_extensionality_mat_equiv. + intros x y. + rewrite (Mmult_assoc (_ × _)). + rewrite kron_comm_commutes_vectors_l_mat_equiv. + rewrite Mmult_assoc, kron_mixed_product. + rewrite kron_comm_commutes_vectors_l_mat_equiv. + rewrite <- kron_mixed_product. + easy. +Qed. + +Lemma kron_comm_commutes : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) × (kron_comm s t) = (B ⊗ A). +Proof. + intros n s m t A B HwfA HwfB. + apply kron_extensionality; + auto with wf_db. + intros x y Hwfx Hwfy. + rewrite (Mmult_assoc (_ × _)). + rewrite kron_comm_commutes_vectors_l by easy. + rewrite Mmult_assoc, kron_mixed_product. + rewrite kron_comm_commutes_vectors_l by auto with wf_db. + rewrite <- kron_mixed_product. + easy. +Qed. + +Lemma commute_kron_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + (A ⊗ B) ≡ kron_comm n m × (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B. + now rewrite kron_comm_commutes_mat_equiv. +Qed. + +Lemma commute_kron : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + (A ⊗ B) = kron_comm n m × (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HA HB. + now rewrite kron_comm_commutes. +Qed. + +Lemma kron_comm_mul_inv_mat_equiv : forall p q, + kron_comm p q × kron_comm q p ≡ Matrix.I (p * q). +Proof. + intros p q. + unfold kron_comm. + rewrite 2!make_WF_equiv. + intros i j Hi Hj. + unfold Mmult. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite Nat.eqb_sym. + rewrite Cmult_if_if_1_l. + replace ((k mod p =? i / q) && (k / p =? i mod q) && + ((k / p =? j mod q) && (j / q =? k mod p))) with + ((k mod p =? i / q) && (k / p =? i mod q) && + ((i mod q =? j mod q) && (i / q =? j / q))) by bdestructΩ'. + rewrite <- eqb_iff_div_mod_eqb. + rewrite andb_comm, andb_if. + reflexivity. + } + unfold I. + simplify_bools_lia_one_kernel. + bdestructΩ'. + - apply big_sum_unique. + exists (j / q + j mod q * p)%nat. + split; [show_moddy_lt|]. + rewrite Nat.Div0.mod_add, Nat.div_add, (Nat.div_small (_/_)), Nat.add_0_l, + Nat.mod_small, 2!Nat.eqb_refl + by lia + show_moddy_lt. + simpl_bools. + split; [easy|]. + intros k Hk. + rewrite <- Nat.eqb_neq. + rewrite (eqb_iff_div_mod_eqb p). + rewrite Nat.Div0.mod_add, Nat.div_add, (Nat.div_small (_/_)), Nat.add_0_l, + Nat.mod_small + by lia + show_moddy_lt. + bdestructΩ'. + - now apply (@big_sum_0 C). +Qed. + +Lemma kron_comm_mul_inv : forall p q, + kron_comm p q × kron_comm q p = Matrix.I _. +Proof. + intros p q. + apply mat_equiv_eq; auto with wf_db. + rewrite kron_comm_mul_inv_mat_equiv; easy. +Qed. + +Lemma kron_comm_mul_transpose_r_mat_equiv : forall p q, + kron_comm p q × (kron_comm p q) ⊤ ≡ Matrix.I _. +Proof. + intros p q. + rewrite (kron_comm_transpose p q). + apply kron_comm_mul_inv_mat_equiv. +Qed. + +Lemma kron_comm_mul_transpose_r : forall p q, + kron_comm p q × (kron_comm p q) ⊤ = Matrix.I _. +Proof. + intros p q. + rewrite (kron_comm_transpose p q). + apply kron_comm_mul_inv. +Qed. + +Lemma kron_comm_mul_transpose_l_mat_equiv : forall p q, + (kron_comm p q) ⊤ × kron_comm p q ≡ Matrix.I _. +Proof. + intros p q. + rewrite <- (kron_comm_transpose q p). + rewrite (transpose_involutive _ _ (kron_comm q p)). + apply kron_comm_mul_transpose_r_mat_equiv. +Qed. + +Lemma kron_comm_mul_transpose_l : forall p q, + (kron_comm p q) ⊤ × kron_comm p q = Matrix.I _. +Proof. + intros p q. + rewrite <- (kron_comm_transpose q p). + rewrite (transpose_involutive _ _ (kron_comm q p)). + apply kron_comm_mul_transpose_r. +Qed. + +Lemma kron_comm_commutes_l_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + kron_comm m n × (A ⊗ B) ≡ (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B. + match goal with |- ?A ≡ ?B => + rewrite <- (Mmult_1_r_mat_eq _ _ A), <- (Mmult_1_r_mat_eq _ _ B) + end. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_mul_transpose_r), <- 2!Mmult_assoc. + rewrite (kron_comm_commutes_mat_equiv n s m t). + apply Mmult_simplify_mat_equiv; [|easy]. + rewrite Mmult_assoc. + restore_dims. + rewrite (kron_comm_mul_inv_mat_equiv t s), Mmult_1_r_mat_eq. + easy. +Qed. + +Lemma kron_comm_commutes_l : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) = (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HwfA HwfB. + apply mat_equiv_eq; auto with wf_db. + apply kron_comm_commutes_l_mat_equiv. +Qed. + +Lemma kron_comm_commutes_r_mat_equiv : forall n s m t + (A : Matrix n s) (B : Matrix m t), + (A ⊗ B) × kron_comm s t ≡ (kron_comm n m) × (B ⊗ A). +Proof. + intros. + rewrite kron_comm_commutes_l_mat_equiv; easy. +Qed. + +Lemma kron_comm_commutes_r : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + (A ⊗ B) × kron_comm s t = (kron_comm n m) × (B ⊗ A). +Proof. + intros n s m t A B HA HB. + rewrite kron_comm_commutes_l; easy. +Qed. + + + +(* Lemma kron_comm_commutes_r : forall n s m t + (A : Matrix n s) (B : Matrix m t), + WF_Matrix A -> WF_Matrix B -> + kron_comm m n × (A ⊗ B) = (B ⊗ A) × (kron_comm t s). +Proof. + intros n s m t A B HwfA HwfB. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_r _ _ A), <- (Mmult_1_r _ _ B) ; auto with wf_db + end. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_mul_transpose_r), <- 2!Mmult_assoc. + rewrite (kron_comm_commutes n s m t) by easy. + apply Mmult_simplify; [|easy]. + rewrite Mmult_assoc. + rewrite (Nat.mul_comm s t), (kron_comm_mul_inv t s), Mmult_1_r by auto with wf_db. + easy. +Qed. *) + + +Lemma vector_eq_basis_comb_mat_equiv : forall n (y : Vector n), + y ≡ big_sum (fun i => y i O .* @e_i n i) n. +Proof. + intros n y. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + symmetry. + rewrite Msum_Csum. + apply big_sum_unique. + exists i. + repeat split; try easy. + - unfold ".*", e_i; bdestructΩ'simp. + - intros l Hl Hnk. + unfold ".*", e_i; bdestructΩ'simp. +Qed. + + +Lemma vector_eq_basis_comb : forall n (y : Vector n), + WF_Matrix y -> + y = big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y Hwfy. + apply mat_equiv_eq; auto with wf_db. + apply vector_eq_basis_comb_mat_equiv. +Qed. + +(* Lemma kron_vecT_matrix_vec : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + WF_Matrix y -> WF_Matrix z -> WF_Matrix P -> + (z⊤) ⊗ P ⊗ y = @Mmult (m*n) (m*n) (o*p) (kron_comm m n) ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z Hwfy Hwfz HwfP. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_l _ _ A) ; auto with wf_db + end. + rewrite Nat.mul_1_l. + rewrite <- (kron_comm_mul_transpose_r), Mmult_assoc at 1. + rewrite Nat.mul_1_r, (Nat.mul_comm o p). + apply Mmult_simplify; [easy|]. + rewrite kron_comm_kron_form_sum. + rewrite Msum_transpose. + rewrite Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤ ⊗ Matrix.I m) (@e_i n k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤) (Matrix.I m)) as H; + rewrite Nat.mul_1_l in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, Mmult_1_l by easy. + rewrite e_i_dot_is_component, transpose_involutive by easy. + (* rewrite <- Mmult_transpose. *) + rewrite Mscale_kron_dist_r, <- 2!Mscale_kron_dist_l. + rewrite kron_1_r. + rewrite <- Mscale_mult_dist_l. + reflexivity. + } + rewrite <- (kron_Msum_distr_r n _ P). + rewrite <- (Mmult_Msum_distr_r). + rewrite <- vector_eq_basis_comb by easy. + easy. +Qed. +*) + +Lemma kron_vecT_matrix_vec_mat_equiv : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + (z⊤) ⊗ P ⊗ y ≡ kron_comm m n × ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z. + match goal with |- ?A ≡ ?B => + rewrite <- (Mmult_1_l_mat_eq _ _ A) + end. + rewrite Nat.mul_1_l. + rewrite <- (kron_comm_mul_transpose_r_mat_equiv), Mmult_assoc at 1. + rewrite Nat.mul_1_r. + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite kron_comm_kron_form_sum_mat_equiv. + replace (m * n)%nat with (1 * m * n)%nat by lia. + replace (n * m)%nat with (n * m * 1)%nat by lia. + rewrite (Msum_transpose (1*m*n) (n*m*1) n). + restore_dims. + rewrite Mmult_Msum_distr_r. + replace (n * m * 1)%nat with (1 * m * n)%nat by lia. + replace (p * o)%nat with (p * o * 1)%nat by lia. + rewrite (Nat.mul_1_r (p * o * 1)). + erewrite (big_sum_mat_equiv_bounded _ _ n). + 2: { + intros k Hk. + unshelve (instantiate (1:=_)). + refine (fun k : nat => y k 0%nat .* e_i k × (z) ⊤ ⊗ P); exact n. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤ ⊗ Matrix.I m) (@e_i n k)) as H; + rewrite Nat.mul_1_l, Nat.mul_1_r, (Nat.mul_comm m n) in *; + rewrite H; clear H. + pose proof (kron_transpose _ _ _ _ ((@e_i n k) ⊤) (Matrix.I m)) as H; + rewrite Nat.mul_1_l in *; + rewrite H; clear H. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite (id_transpose_eq m). + rewrite Mscale_mult_dist_l, transpose_involutive. + rewrite <- (kron_1_r _ _ P) at 2. + rewrite Mscale_kron_dist_l, <- !Mscale_kron_dist_r. + rewrite kron_assoc_mat_equiv. + restore_dims. + apply kron_simplify_mat_equiv; [easy|]. + rewrite <- Mscale_kron_dist_r. + rewrite Mmult_1_l_mat_eq. + apply kron_simplify_mat_equiv; [easy|]. + rewrite (e_i_dot_is_component_mat_equiv); easy. + } + rewrite <- (kron_Msum_distr_r n _ P). + rewrite <- (Mmult_Msum_distr_r). + replace (1*m*n)%nat with (n*m)%nat by lia. + replace (p*o*1)%nat with (p*o)%nat by lia. + apply kron_simplify_mat_equiv; [|easy]. + apply Mmult_simplify_mat_equiv; [|easy]. + symmetry. + apply vector_eq_basis_comb_mat_equiv. +Qed. + +Lemma kron_vecT_matrix_vec : forall m n o p + (P : Matrix m o) (y : Vector n) (z : Vector p), + WF_Matrix y -> WF_Matrix z -> WF_Matrix P -> + (z⊤) ⊗ P ⊗ y = kron_comm m n × ((y × (z⊤)) ⊗ P). +Proof. + intros m n o p P y z Hwfy Hwfz HwfP. + apply mat_equiv_eq; + [|rewrite ?Nat.mul_1_l, ?Nat.mul_1_r; apply WF_mult|]; + auto with wf_db. + apply kron_vecT_matrix_vec_mat_equiv. +Qed. + +Lemma kron_vec_matrix_vecT : forall m n o p + (Q : Matrix n o) (x : Vector m) (z : Vector p), + WF_Matrix x -> WF_Matrix z -> WF_Matrix Q -> + x ⊗ Q ⊗ (z⊤) = kron_comm m n × (Q ⊗ (x × z⊤)). +Proof. + intros m n o p Q x z Hwfx Hwfz HwfQ. + match goal with |- ?A = ?B => + rewrite <- (Mmult_1_l _ _ A) ; auto with wf_db + end. + rewrite Nat.mul_1_r. + rewrite <- (kron_comm_mul_transpose_r), Mmult_assoc at 1. + rewrite Nat.mul_1_l. + apply Mmult_simplify; [easy|]. + rewrite kron_comm_kron_form_sum'. + rewrite (Msum_transpose (m*n) (n*m) m). + restore_dims. + rewrite Mmult_Msum_distr_r. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + restore_dims. + replace (@transpose (m*n) (n*m)) with + (@transpose (m*n*1) (1*n*m)) by (f_equal; lia). + rewrite kron_transpose. + rewrite kron_transpose, transpose_involutive. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, Mmult_1_l by easy. + rewrite e_i_dot_is_component, transpose_involutive by easy. + rewrite 2!Mscale_kron_dist_l, kron_1_l, <-Mscale_kron_dist_r by easy. + rewrite <- Mscale_mult_dist_l. + restore_dims. + reflexivity. + } + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite transpose_involutive. + reflexivity. + } + rewrite <- (kron_Msum_distr_l m _ Q). + rewrite <- (Mmult_Msum_distr_r). + rewrite <- vector_eq_basis_comb by easy. + easy. +Qed. + +Lemma kron_vec_matrix_vecT_mat_equiv : forall m n o p + (Q : Matrix n o) (x : Vector m) (z : Vector p), + x ⊗ Q ⊗ (z⊤) ≡ kron_comm m n × (Q ⊗ (x × z⊤)). +Proof. + intros m n o p Q x z. + match goal with |- ?A ≡ ?B => + rewrite <- (Mmult_1_l_mat_eq _ _ A) + end. + rewrite Nat.mul_1_r. + rewrite <- (kron_comm_mul_transpose_r_mat_equiv), Mmult_assoc at 1. + rewrite Nat.mul_1_l. + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite kron_comm_kron_form_sum'. + replace (@transpose (m*n) (n*m)) with + (@transpose (m*n*1) (1*n*m)) by (f_equal; lia). + rewrite (Msum_transpose (m*n*1) (1*n*m) m). + restore_dims. + rewrite Mmult_Msum_distr_r. + replace (@mat_equiv (n*m) (o*p)) + with (@mat_equiv (m*n*1) (1*o*p)) by (f_equal; lia). + erewrite (big_sum_mat_equiv_bounded). + 2: { + intros k Hk. + unshelve (instantiate (1:=(fun k : nat => + @kron n o m p Q + (@Mmult m 1 p (@scale m 1 (x k 0%nat) (@e_i m k)) + (@transpose p 1 z))))). + rewrite 2!kron_transpose. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite id_transpose_eq, transpose_involutive. + rewrite Mscale_mult_dist_l, Mscale_kron_dist_r, <- Mscale_kron_dist_l. + replace (m*n*1)%nat with (1*n*m)%nat by lia. + replace (@kron n o m p) with (@kron (1*n) (1*o) m p) by (f_equal; lia). + apply kron_simplify_mat_equiv; [|easy]. + intros i j Hi Hj. + unfold kron. + rewrite (Mmult_1_l_mat_eq _ _ Q) by (apply Nat.mod_upper_bound; lia). + (* revert i j Hi Hj. *) + rewrite (e_i_dot_is_component_mat_equiv m k x Hk) by (apply Nat.Div0.div_lt_upper_bound; lia). + set (a:= (@kron 1 1 n o ((x k 0%nat .* Matrix.I 1)) Q) i j). + match goal with + |- ?A = _ => change A with a + end. + unfold a. + clear a. + rewrite Mscale_kron_dist_l. + unfold scale. + rewrite kron_1_l_mat_equiv by lia. + easy. + } + rewrite <- (kron_Msum_distr_l m _ Q). + rewrite <- (Mmult_Msum_distr_r). + rewrite (Nat.mul_comm m n). + rewrite Nat.mul_1_r, Nat.mul_1_l. + rewrite <- vector_eq_basis_comb_mat_equiv. + easy. +Qed. + +Lemma kron_comm_triple_cycle_mat : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), + A ⊗ B ⊗ C ≡ (kron_comm (m*s) p) × (C ⊗ A ⊗ B) × (kron_comm q (t*n)). +Proof. + intros m n s t p q A B C. + rewrite (commute_kron_mat_equiv _ _ _ _ (A ⊗ B) C) by auto with wf_db. + rewrite (Nat.mul_comm n t), (Nat.mul_comm q (t*n)). + apply Mmult_simplify_mat_equiv; [|easy]. + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite (Nat.mul_comm t n). + intros i j Hi Hj; + rewrite <- (kron_assoc_mat_equiv C A B); + [easy|lia|lia]. +Qed. + +Lemma kron_comm_triple_cycle : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), WF_Matrix A -> WF_Matrix B -> WF_Matrix C -> + A ⊗ B ⊗ C = (kron_comm (m*s) p) × (C ⊗ A ⊗ B) × (kron_comm q (t*n)). +Proof. + intros m n s t p q A B C HA HB HC. + rewrite (commute_kron _ _ _ _ (A ⊗ B) C) by auto with wf_db. + rewrite kron_assoc by easy. + f_equal; try lia; f_equal; lia. +Qed. + +Lemma kron_comm_triple_cycle2_mat_equiv : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), + A ⊗ (B ⊗ C) ≡ (kron_comm m (s*p)) × (B ⊗ C ⊗ A) × (kron_comm (q*t) n). +Proof. + intros m n s t p q A B C. + rewrite kron_assoc_mat_equiv. + intros i j Hi Hj. + rewrite (commute_kron_mat_equiv _ _ _ _ A (B ⊗ C)) by lia. + rewrite (Nat.mul_comm t q). + apply Mmult_simplify_mat_equiv; [|easy + lia..]. + apply Mmult_simplify_mat_equiv; [easy|]. + rewrite (Nat.mul_comm q t). + apply kron_assoc_mat_equiv. +Qed. + +Lemma kron_comm_triple_cycle2 : forall m n s t p q (A : Matrix m n) + (B : Matrix s t) (C : Matrix p q), WF_Matrix A -> WF_Matrix B -> WF_Matrix C -> + A ⊗ (B ⊗ C) = (kron_comm m (s*p)) × (B ⊗ C ⊗ A) × (kron_comm (q*t) n). +Proof. + intros m n s t p q A B C HA HB HC. + apply mat_equiv_eq; [auto using WF_Matrix_dim_change with wf_db zarith..|]. + apply kron_comm_triple_cycle2_mat_equiv. +Qed. + + + + + + + +Lemma id_eq_sum_kron_e_is_mat_equiv : forall n, + Matrix.I n ≡ big_sum (G:=Square n) (fun i => @e_i n i ⊗ (@e_i n i) ⊤) n. +Proof. + intros n. + symmetry. + intros i j Hi Hj. + rewrite Msum_Csum. + erewrite big_sum_eq_bounded. + 2: { + intros k Hk. + rewrite kron_e_i_l by lia. + unfold transpose, e_i. + rewrite <- andb_if. + replace_bool_lia (j @e_i n i ⊗ (@e_i n i) ⊤) n. +Proof. + intros n. + apply mat_equiv_eq; auto with wf_db. + apply id_eq_sum_kron_e_is_mat_equiv. +Qed. + +Lemma kron_comm_cycle_indices : forall t s n, + kron_comm (t*s) n = @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) + (kron_comm s (n*t)) (kron_comm t (s*n)). +Proof. + intros t s n. + rewrite kron_comm_kron_form_sum. + erewrite big_sum_eq_bounded. + 2: { + intros j Hj. + rewrite (Nat.mul_comm t s), <- id_kron, <- kron_assoc by auto with wf_db. + restore_dims. + rewrite kron_assoc by auto with wf_db. + (* rewrite (kron_assoc ((@e_i n j)⊤ ⊗ Matrix.I t) (Matrix.I s) (@e_i n j)) by auto with wf_db. *) + lazymatch goal with + |- ?A ⊗ ?B = _ => rewrite (commute_kron _ _ _ _ A B) by auto with wf_db + end. + (* restore_dims. *) + reflexivity. + } + (* rewrite ?Nat.mul_1_r, ?Nat.mul_1_l. *) + (* rewrite <- Mmult_Msum_distr_r. *) + + rewrite <- (Mmult_Msum_distr_r n _ (kron_comm (t*1) (n*s))). + rewrite <- Mmult_Msum_distr_l. + erewrite big_sum_eq_bounded. + 2: { + intros j Hj. + rewrite <- kron_assoc, (kron_assoc (Matrix.I t)) by auto with wf_db. + restore_dims. + reflexivity. + } + (* rewrite Nat.mul_1_l *) + rewrite <- (kron_Msum_distr_r n _ (Matrix.I s)). + rewrite <- (kron_Msum_distr_l n _ (Matrix.I t)). + rewrite 2!Nat.mul_1_r, 2!Nat.mul_1_l. + rewrite <- (id_eq_sum_kron_e_is n). + rewrite 2!id_kron. + restore_dims. + rewrite Mmult_1_r by auto with wf_db. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + easy. +Qed. + +Lemma kron_comm_cycle_indices_mat_equiv : forall t s n, + (kron_comm (t*s) n ≡ @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n))). +Proof. + intros t s n. + rewrite kron_comm_cycle_indices; easy. +Qed. + +Lemma kron_comm_cycle_indices_rev : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) = kron_comm (t*s) n. +Proof. + intros. + rewrite <- kron_comm_cycle_indices. + easy. +Qed. + +Lemma kron_comm_cycle_indices_rev_mat_equiv : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) ≡ kron_comm (t*s) n. +Proof. + intros. + rewrite <- kron_comm_cycle_indices. + easy. +Qed. + +Lemma kron_comm_triple_id : forall t s n, + (kron_comm (t*s) n) × (kron_comm (s*n) t) × (kron_comm (n*t) s) = Matrix.I (t*s*n). +Proof. + intros t s n. + rewrite kron_comm_cycle_indices. + restore_dims. + rewrite (Mmult_assoc (kron_comm s (n*t))). + restore_dims. + rewrite (kron_comm_mul_inv t (s*n)). + restore_dims. + rewrite Mmult_1_r by auto with wf_db. + rewrite (kron_comm_mul_inv). + f_equal; lia. +Qed. + +Lemma kron_comm_triple_id_mat_equiv : forall t s n, + (kron_comm (t*s) n) × (kron_comm (s*n) t) × (kron_comm (n*t) s) ≡ Matrix.I (t*s*n). +Proof. + intros t s n. + setoid_rewrite kron_comm_triple_id; easy. +Qed. + +Lemma kron_comm_triple_id' : forall n t s, + (kron_comm n (t*s)) × (kron_comm t (s*n)) × (kron_comm s (n*t)) = Matrix.I (t*s*n). +Proof. + intros n t s. + apply transpose_matrices. + rewrite 2!Mmult_transpose. + rewrite (kron_comm_transpose s (n*t)). + rewrite (kron_comm_transpose n (t*s)). + restore_dims. + rewrite (Nat.mul_assoc s n t), <- (Nat.mul_assoc t s n). + + rewrite (kron_comm_transpose t (s*n)). + rewrite Nat.mul_assoc. + replace (t*(s*n))%nat with (n*t*s)%nat by lia. + rewrite id_transpose_eq. + replace (n*t*s)%nat with (t*n*s)%nat by lia. + rewrite <- (kron_comm_triple_id t n s). + rewrite Mmult_assoc. + restore_dims. + replace (s*(t*n))%nat with (s*(n*t))%nat by lia. + replace (n*(t*s))%nat with (n*(s*t))%nat by lia. + replace (n*t*s)%nat with (t*n*s)%nat by lia. + apply Mmult_simplify; [f_equal; lia|]. + repeat (f_equal; try lia). +Qed. + +Lemma kron_comm_triple_id'_mat_equiv : forall t s n, + (kron_comm n (t*s)) × (kron_comm t (s*n)) × (kron_comm s (n*t)) = Matrix.I (t*s*n). +Proof. + intros t s n. + rewrite (kron_comm_triple_id' n t s). + easy. +Qed. + +Lemma kron_comm_triple_id'C : forall n t s, + (kron_comm n (s*t)) × (kron_comm t (n*s)) × (kron_comm s (t*n)) = Matrix.I (t*s*n). +Proof. + intros n t s. + rewrite <- (kron_comm_triple_id' n t s). + rewrite (Nat.mul_comm s t), (Nat.mul_comm n s), + (Nat.mul_comm t n). + easy. +Qed. + +Lemma kron_comm_triple_id'C_mat_equiv : forall n t s, + (kron_comm n (s*t)) × (kron_comm t (n*s)) × (kron_comm s (t*n)) ≡ Matrix.I (t*s*n). +Proof. + intros n t s. + rewrite <- (kron_comm_triple_id'C n t s). + easy. +Qed. + +Lemma kron_comm_triple_indices_collapse_mat_equiv : forall s n t, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) + ≡ (kron_comm (t*s) n). +Proof. + intros s n t. + rewrite <- (Mmult_1_r_mat_eq _ _ (_ × _)). + (* replace (t*(s*n))%nat with (n*(t*s))%nat by lia. *) + rewrite <- (kron_comm_mul_inv_mat_equiv). + rewrite <- Mmult_assoc. + (* restore_dims. *) + pose proof (kron_comm_triple_id'C s t n) as Hrw. + apply (f_equal (fun A => A × kron_comm (t*s) n)) in Hrw. + replace (t*n*s)%nat with (t*s*n)%nat in Hrw by lia. + restore_dims in Hrw. + rewrite (Mmult_1_l _ _ (kron_comm (t*s) n)) in Hrw by auto with wf_db. + rewrite <- Hrw. + rewrite !Mmult_assoc. + restore_dims. + replace (n*(t*s))%nat with (t*(s*n))%nat by lia. + apply Mmult_simplify_mat_equiv; [easy|]. + replace (n*t*s)%nat with (t*(s*n))%nat by lia. + apply Mmult_simplify_mat_equiv; [easy|]. + restore_dims. + rewrite 2!kron_comm_mul_inv. + now replace (t*(s*n))%nat with (n*(t*s))%nat by lia. +Qed. + +Lemma kron_comm_triple_indices_collapse : forall s n t, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) + = (kron_comm (t*s) n). +Proof. + intros s n t. + apply mat_equiv_eq; + [restore_dims; apply WF_Matrix_dim_change; [lia..|]..|]; + auto 10 using WF_Matrix_dim_change with wf_db zarith. + apply kron_comm_triple_indices_collapse_mat_equiv. +Qed. + +Lemma kron_comm_triple_indices_collapse_mat_equivC : forall s n t, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) + ≡ (kron_comm (t*s) n). +Proof. + intros s n t. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + rewrite kron_comm_triple_indices_collapse_mat_equiv. + easy. +Qed. + +Lemma kron_comm_triple_indices_collapseC : forall s n t, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) + = (kron_comm (t*s) n). +Proof. + intros s n t. + apply mat_equiv_eq; + [restore_dims; apply WF_Matrix_dim_change; [lia..|]..|]; + auto 10 using WF_Matrix_dim_change with wf_db zarith. + apply kron_comm_triple_indices_collapse_mat_equivC. +Qed. + +(* +Not sure what this is, or if it's true: +Lemma kron_comm_triple_indices_commute : forall t s n, + @Mmult (s*t*n) (s*t*n) (t*(s*n)) (kron_comm (s*t) n) (kron_comm t (s*n)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*t*n) (kron_comm t (s*n)) (kron_comm (s*t) n). *) +Lemma kron_comm_triple_indices_commute_mat_equiv : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) ≡ + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + rewrite kron_comm_triple_indices_collapse_mat_equiv. + rewrite (Nat.mul_comm t s). + rewrite <- (kron_comm_triple_indices_collapseC t n s). + easy. +Qed. + +Lemma kron_comm_triple_indices_commute : forall t s n, + @Mmult (s*(n*t)) (s*(n*t)) (t*(s*n)) (kron_comm s (n*t)) (kron_comm t (s*n)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + apply mat_equiv_eq; + [restore_dims; apply WF_Matrix_dim_change; [lia..|]..|]; + auto 10 using WF_Matrix_dim_change with wf_db zarith. + apply kron_comm_triple_indices_commute_mat_equiv. +Qed. + +Lemma kron_comm_triple_indices_commute_mat_equivC : forall t s n, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) ≡ + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + apply kron_comm_triple_indices_commute_mat_equiv. +Qed. + +Lemma kron_comm_triple_indices_commuteC : forall t s n, + @Mmult (s*(t*n)) (s*(t*n)) (t*(n*s)) (kron_comm s (t*n)) (kron_comm t (n*s)) = + @Mmult (t*(s*n)) (t*(s*n)) (s*(n*t)) (kron_comm t (s*n)) (kron_comm s (n*t)). +Proof. + intros t s n. + rewrite (Nat.mul_comm t n), (Nat.mul_comm n s). + apply kron_comm_triple_indices_commute. +Qed. + +Lemma kron_comm_kron_of_mult_commute1_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + @mat_equiv (m*p) (s*t) ((kron_comm m p) × ((B × C) ⊗ (A × D))) + ((A ⊗ B) × kron_comm n q × (C ⊗ D)). +Proof. + intros m n p q s t A B C D. + rewrite <- kron_mixed_product. + rewrite (Nat.mul_comm p m), <- Mmult_assoc. + rewrite kron_comm_commutes_r_mat_equiv. + match goal with (* TODO: Make a lemma *) + |- ?A ≡ ?B => enough (H : A = B) by (rewrite H; easy) + end. + repeat (f_equal; try lia). +Qed. + +Lemma kron_comm_kron_of_mult_commute2_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + ((A ⊗ B) × kron_comm n q × (C ⊗ D)) ≡ (A × D ⊗ (B × C)) × kron_comm t s. +Proof. + intros m n p q s t A B C D. + rewrite Mmult_assoc, kron_comm_commutes_l_mat_equiv, <-Mmult_assoc, + <- kron_mixed_product. + easy. +Qed. + +Lemma kron_comm_kron_of_mult_commute3_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + (A × D ⊗ (B × C)) × kron_comm t s ≡ + (Matrix.I m) ⊗ (B × C) × kron_comm m s × (Matrix.I s ⊗ (A × D)). +Proof. + intros m n p q s t A B C D. + rewrite <- 2!kron_comm_commutes_l_mat_equiv, Mmult_assoc. + restore_dims. + rewrite kron_mixed_product. + rewrite Mmult_1_r_mat_eq, Mmult_1_l_mat_eq. + easy. +Qed. + +Lemma kron_comm_kron_of_mult_commute4_mat_equiv : forall m n p q s t + (A : Matrix m n) (B : Matrix p q) (C : Matrix q s) (D : Matrix n t), + @mat_equiv (m*p) (s*t) + ((Matrix.I m) ⊗ (B × C) × kron_comm m s × (Matrix.I s ⊗ (A × D))) + ((A × D) ⊗ (Matrix.I p) × kron_comm t p × ((B × C) ⊗ Matrix.I t)). +Proof. + intros m n p q s t A B C D. + rewrite <- 2!kron_comm_commutes_l_mat_equiv, 2!Mmult_assoc. + restore_dims. + rewrite 2!kron_mixed_product. + rewrite (Nat.mul_comm m p), 2!Mmult_1_r_mat_eq. + rewrite 2!Mmult_1_l_mat_eq. + easy. +Qed. \ No newline at end of file diff --git a/Matrix.v b/Matrix.v index 5ea0453..ebad47d 100644 --- a/Matrix.v +++ b/Matrix.v @@ -6,7 +6,8 @@ Require Import String. Require Import Program. Require Export Complex. Require Import List. - +Require Import Setoid. +Require Import Modulus. (* TODO: Use matrix equality everywhere, declare equivalence relation *) @@ -67,8 +68,8 @@ Definition mat_equiv {m n : nat} (A B : Matrix m n) : Prop := Infix "==" := mat_equiv (at level 70) : matrix_scope. Lemma mat_equiv_refl : forall m n (A : Matrix m n), mat_equiv A A. - Proof. unfold mat_equiv; reflexivity. Qed. + Lemma mat_equiv_eq : forall {m n : nat} (A B : Matrix m n), WF_Matrix A -> WF_Matrix B -> @@ -85,6 +86,19 @@ Proof. + rewrite WFA, WFB; trivial; left; try lia. Qed. +Lemma WF_Matrix_dim_change : forall (m n m' n' : nat) (A : Matrix m n), + m = m' -> + n = n' -> + @WF_Matrix m n A -> + @WF_Matrix m' n' A. +Proof. intros. subst. easy. Qed. + +(** Equality via bounded equality for WF matrices **) +Ltac prep_matrix_equivalence := + apply mat_equiv_eq; + [solve [auto 100 with wf_db | + auto 100 using WF_Matrix_dim_change with wf_db zarith]..|]. + (** Printing *) Parameter print_C : C -> string. @@ -127,6 +141,22 @@ Proof. simpl; lia. Qed. +Lemma show_WF_list2D_to_matrix m n li : + length li = m -> + forallb (fun x => length x =? n) li = true -> + @WF_Matrix m n (list2D_to_matrix li). +Proof. + intros Hlen. + rewrite forallb_forall. + intros Hin. + setoid_rewrite Nat.eqb_eq in Hin. + apply WF_list2D_to_matrix. + easy. + intros l Hl. + rewrite Hin by easy. + easy. +Qed. + (** Example *) Definition M23 : Matrix 2 3 := fun x y => @@ -162,6 +192,8 @@ Definition Zero {m n : nat} : Matrix m n := fun x y => 0%R. Definition I (n : nat) : Square n := (fun x y => if (x =? y) && (x c). (* in many cases, n needs to be made explicit, but not always, hence it is made implicit here *) Definition e_i {n : nat} (i : nat) : Vector n := @@ -290,7 +322,7 @@ Notation "⨂ A" := (big_kron A) (at level 60): matrix_scope. Notation "n ⨉ A" := (Mmult_n n A) (at level 30, no associativity) : matrix_scope. Notation "⟨ u , v ⟩" := (inner_product u v) (at level 0) : matrix_scope. #[export] Hint Unfold Zero I e_i trace dot Mplus scale Mmult kron mat_equiv transpose - adjoint : U_db. + adjoint const_matrix make_WF : U_db. Ltac destruct_m_1 := match goal with @@ -311,9 +343,8 @@ Ltac solve_end := match goal with | H : lt _ O |- _ => apply Nat.nlt_0_r in H; contradict H end. - -Ltac by_cell := - intros; + +Ltac by_cell_no_intros := let i := fresh "i" in let j := fresh "j" in let Hi := fresh "Hi" in @@ -322,6 +353,10 @@ Ltac by_cell := repeat (destruct i as [|i]; simpl; [|apply <- Nat.succ_lt_mono in Hi]; try solve_end); clear Hi; repeat (destruct j as [|j]; simpl; [|apply <- Nat.succ_lt_mono in Hj]; try solve_end); clear Hj. +Ltac by_cell := + intros; + by_cell_no_intros. + Ltac lma' := apply mat_equiv_eq; repeat match goal with @@ -375,18 +410,224 @@ Lemma Mscale_simplify : forall (n m: nat) (a b : Matrix n m) (c d : C), Proof. intros; subst; easy. Qed. +(** * Proofs about mat_equiv *) + +Lemma mat_equiv_sym : forall {n m : nat} (A B : Matrix n m), + A ≡ B -> B ≡ A. +Proof. + intros n m A B HAB i j Hi Hj. + rewrite HAB by easy. + easy. +Qed. + +Lemma mat_equiv_trans : forall {n m : nat} (A B C : Matrix n m), + A ≡ B -> B ≡ C -> A ≡ C. +Proof. + intros n m A B C HAB HBC i j Hi Hj. + rewrite HAB, HBC by easy. + easy. +Qed. + +#[global] Add Parametric Relation {n m} : (Matrix n m) mat_equiv + reflexivity proved by (mat_equiv_refl _ _) + symmetry proved by (mat_equiv_sym) + transitivity proved by (mat_equiv_trans) + as mat_equiv_rel. + +Lemma mat_equiv_eq_iff {n m} : forall (A B : Matrix n m), + WF_Matrix A -> WF_Matrix B -> A ≡ B <-> A = B. +Proof. + intros; split; try apply mat_equiv_eq; + intros; try subst A; easy. +Qed. + +Lemma Mmult_simplify_mat_equiv : forall {n m o} + (A B : Matrix n m) (C D : Matrix m o), + A ≡ B -> C ≡ D -> A × C ≡ B × D. +Proof. + intros n m o A B C D HAB HCD. + intros i j Hi Hj. + unfold Mmult. + apply big_sum_eq_bounded. + intros k Hk. + rewrite HAB, HCD by easy. + easy. +Qed. + +Add Parametric Morphism {n m o} : (@Mmult n m o) + with signature (@mat_equiv n m) ==> (@mat_equiv m o) ==> (@mat_equiv n o) + as mmult_mat_equiv_morph. +Proof. intros; apply Mmult_simplify_mat_equiv; easy. Qed. + +Lemma kron_simplify_mat_equiv {n m o p} : forall (A B : Matrix n m) + (C D : Matrix o p), A ≡ B -> C ≡ D -> A ⊗ C ≡ B ⊗ D. +Proof. + intros A B C D HAB HCD i j Hi Hj. + unfold kron. + rewrite HAB, HCD; try easy. + 1,2: apply Nat.mod_upper_bound; lia. + 1,2: apply Nat.Div0.div_lt_upper_bound; lia. +Qed. + +Add Parametric Morphism {n m o p} : (@kron n m o p) + with signature (@mat_equiv n m) ==> (@mat_equiv o p) + ==> (@mat_equiv (n*o) (m*p)) as kron_mat_equiv_morph. +Proof. intros; apply kron_simplify_mat_equiv; easy. Qed. + +Lemma Mplus_simplify_mat_equiv : forall {n m} + (A B C D : Matrix n m), + A ≡ B -> C ≡ D -> A .+ C ≡ B .+ D. +Proof. + intros n m A B C D HAB HCD. + intros i j Hi Hj; unfold ".+"; + rewrite HAB, HCD; try easy. +Qed. + +Add Parametric Morphism {n m} : (@Mplus n m) + with signature (@mat_equiv n m) ==> (@mat_equiv n m) ==> (@mat_equiv n m) + as Mplus_mat_equiv_morph. +Proof. intros; apply Mplus_simplify_mat_equiv; easy. Qed. + +Lemma scale_simplify_mat_equiv : forall {n m} + (x y : C) (A B : Matrix n m), + x = y -> A ≡ B -> x .* A ≡ y .* B. +Proof. + intros n m x y A B Hxy HAB i j Hi Hj. + unfold scale. + rewrite Hxy, HAB; easy. +Qed. + +Add Parametric Morphism {n m} : (@scale n m) + with signature (@eq C) ==> (@mat_equiv n m) ==> (@mat_equiv n m) + as scale_mat_equiv_morph. +Proof. intros; apply scale_simplify_mat_equiv; easy. Qed. + +Lemma Mopp_simplify_mat_equiv : forall {n m} (A B : Matrix n m), + A ≡ B -> Mopp A ≡ Mopp B. +Proof. + intros n m A B HAB i j Hi Hj. + unfold Mopp, scale. + rewrite HAB; easy. +Qed. + +Add Parametric Morphism {n m} : (@Mopp n m) + with signature (@mat_equiv n m) ==> (@mat_equiv n m) + as Mopp_mat_equiv_morph. +Proof. intros; apply Mopp_simplify_mat_equiv; easy. Qed. + +Lemma Mminus_simplify_mat_equiv : forall {n m} + (A B C D : Matrix n m), + A ≡ B -> C ≡ D -> Mminus A C ≡ Mminus B D. +Proof. + intros n m A B C D HAB HCD. + intros i j Hi Hj; unfold Mminus, Mopp, Mplus, scale; + rewrite HAB, HCD; try easy. +Qed. + +Add Parametric Morphism {n m} : (@Mminus n m) + with signature (@mat_equiv n m) ==> (@mat_equiv n m) ==> (@mat_equiv n m) + as Mminus_mat_equiv_morph. +Proof. intros; apply Mminus_simplify_mat_equiv; easy. Qed. + +Lemma dot_simplify_mat_equiv : forall {n} (A B : Vector n) + (C D : Vector n), A ≡ B -> C ≡ D -> dot A C = dot B D. +Proof. + intros n A B C D HAB HCD. + apply big_sum_eq_bounded. + intros k Hk. + rewrite HAB, HCD; unfold "<"%nat; easy. +Qed. + +Add Parametric Morphism {n} : (@dot n) + with signature (@mat_equiv n 1) ==> (@mat_equiv n 1) ==> (@eq C) + as dot_mat_equiv_morph. +Proof. intros; apply dot_simplify_mat_equiv; easy. Qed. + +Lemma transpose_simplify_mat_equiv {n m} : forall (A B : Matrix n m), + A ≡ B -> A ⊤ ≡ B ⊤. +Proof. + intros A B HAB i j Hi Hj. + unfold transpose; auto. +Qed. + +Lemma transpose_simplify_mat_equiv_inv {n m} : forall (A B : Matrix n m), + A ⊤ ≡ B ⊤ -> A ≡ B. +Proof. + intros A B HAB i j Hi Hj. + unfold transpose in *; auto. +Qed. + +Add Parametric Morphism {n m} : (@transpose n m) + with signature (@mat_equiv n m) ==> (@mat_equiv m n) + as transpose_mat_equiv_morph. +Proof. intros; apply transpose_simplify_mat_equiv; easy. Qed. + +Lemma adjoint_simplify_mat_equiv {n m} : forall (A B : Matrix n m), + A ≡ B -> A † ≡ B †. +Proof. + intros A B HAB i j Hi Hj. + unfold adjoint; + rewrite HAB by easy; easy. +Qed. + +Add Parametric Morphism {n m} : (@adjoint n m) + with signature (@mat_equiv n m) ==> (@mat_equiv m n) + as adjoint_mat_equiv_morph. +Proof. intros; apply adjoint_simplify_mat_equiv; easy. Qed. + +Lemma trace_of_mat_equiv : forall n (A B : Square n), + A ≡ B -> trace A = trace B. +Proof. + intros n A B HAB. + (* unfold trace. *) + apply big_sum_eq_bounded; intros i Hi. + rewrite HAB; auto. +Qed. + +Add Parametric Morphism {n} : (@trace n) + with signature (@mat_equiv n n) ==> (eq) + as trace_mat_equiv_morph. +Proof. intros; apply trace_of_mat_equiv; easy. Qed. + +Lemma mat_equiv_equivalence : forall {n m}, + equivalence (Matrix n m) mat_equiv. +Proof. + intros n m. + constructor. + - intros A. apply (mat_equiv_refl). + - intros A; apply mat_equiv_trans. + - intros A; apply mat_equiv_sym. +Qed. + +Lemma big_sum_mat_equiv : forall {o p} (f g : nat -> Matrix o p) + (Eq_on: forall x : nat, f x ≡ g x) (n : nat), big_sum f n ≡ big_sum g n. +Proof. + intros o p f g Eq_on n. + induction n. + - easy. + - simpl. + rewrite IHn, Eq_on; easy. +Qed. + +Add Parametric Morphism {n m} : (@big_sum (Matrix n m) (M_is_monoid n m)) + with signature + (Morphisms.pointwise_relation nat (@mat_equiv n m)) ==> (@eq nat) ==> + (@mat_equiv n m) + as big_sum_mat_equiv_morph. +Proof. intros f g Eq_on k. apply big_sum_mat_equiv; easy. Qed. (** * Proofs about well-formedness **) -Lemma WF_Matrix_dim_change : forall (m n m' n' : nat) (A : Matrix m n), - m = m' -> - n = n' -> - @WF_Matrix m n A -> - @WF_Matrix m' n' A. -Proof. intros. subst. easy. Qed. +Lemma WF_Matrix_dim_change_iff m n m' n' (A : Matrix m n) : + m = m' -> n = n' -> + @WF_Matrix m' n' A <-> WF_Matrix A. +Proof. + intros. + now subst. +Qed. Lemma WF_make_WF : forall {m n} (A : Matrix m n), WF_Matrix (make_WF A). Proof. intros. @@ -396,9 +637,19 @@ Proof. intros. bdestruct (y unify_pows_two : wf_db. (* Utility tactics *) @@ -634,6 +886,14 @@ Ltac solve_wf := collate_wf; easy. (** * Basic matrix lemmas *) +Lemma make_WF_equiv n m (A : Matrix n m) : + make_WF A ≡ A. +Proof. + unfold make_WF. + intros i j Hi Hj. + bdestructΩ'. +Qed. + Lemma mat_equiv_make_WF : forall {m n} (T : Matrix m n), T == make_WF T. Proof. unfold make_WF, mat_equiv; intros. @@ -714,6 +974,67 @@ Proof. easy. Qed. +Lemma trace_0_l : forall (A : Square 0), + trace A = 0. +Proof. + intros A. + unfold trace. + easy. +Qed. + +Lemma trace_0_r : forall n, + trace (@Zero n n) = 0. +Proof. + intros A. + unfold trace. + rewrite big_sum_0; easy. +Qed. + +Lemma trace_mmult_eq_ptwise : forall {n m} (A : Matrix n m) (B : Matrix m n), + trace (A×B) = Σ (fun i => Σ (fun j => A i j * B j i) m) n. +Proof. + reflexivity. +Qed. + +Lemma trace_mmult_comm : forall {n m} (A : Matrix n m) (B : Matrix m n), + trace (A×B) = trace (B×A). +Proof. + intros n m A B. + rewrite 2!trace_mmult_eq_ptwise. + rewrite big_sum_swap_order. + do 2 (apply big_sum_eq_bounded; intros). + apply Cmult_comm. +Qed. + +Lemma trace_transpose : forall {n} (A : Square n), + trace (A ⊤) = trace A. +Proof. + reflexivity. +Qed. + +Lemma trace_kron : forall {n p} (A : Square n) (B : Square p), + trace (A ⊗ B) = trace A * trace B. +Proof. + intros n p A B. + destruct p; + [rewrite Nat.mul_0_r, 2!trace_0_l; lca|]. + unfold trace. + simpl_rewrite big_sum_product; [|easy]. + reflexivity. +Qed. + +Lemma trace_big_sum : forall n k f, + trace (big_sum (G:=Square n) f k) = Σ (fun x => trace (f x)) k. +Proof. + intros n k f. + induction k. + - rewrite trace_0_r; easy. + - rewrite <- 2!big_sum_extend_r, <-IHk. + simpl. + rewrite trace_plus_dist. + easy. +Qed. + Lemma Mplus_0_l : forall (m n : nat) (A : Matrix m n), Zero .+ A = A. Proof. intros. lma. Qed. @@ -833,6 +1154,12 @@ Proof. apply Mmult_1_r_mat_eq. Qed. +Lemma Mmult_1_comm {n m} (A : Matrix n m) (HA : WF_Matrix A) : + I n × A = A × I m. +Proof. + now rewrite Mmult_1_r, Mmult_1_l. +Qed. + (* Cool facts about I∞, not used in the development *) Lemma Mmult_inf_l : forall(m n : nat) (A : Matrix m n), WF_Matrix A -> I∞ × A = A. @@ -878,6 +1205,37 @@ Proof. bdestruct (z =? y); bdestruct (z I 1 ⊗ A = A. Proof. intros m n A WF. - prep_matrix_equality. + apply mat_equiv_eq; + [auto using WF_Matrix_dim_change with wf_db..|]. + apply kron_1_l_mat_equiv. +Qed. + +Lemma kron_1_1_mid_comm {n m} (A : Matrix n 1) (B : Matrix 1 m) + (HA : WF_Matrix A) (HB : WF_Matrix B) : + A ⊗ B = B ⊗ A. +Proof. + apply mat_equiv_eq; [auto with wf_db..|]. + intros i j. unfold kron. - unfold I, kron. - bdestruct (m =? 0). rewrite 2 WF by lia. lca. - bdestruct (n =? 0). rewrite 2 WF by lia. lca. - bdestruct (x / m 0)%nat by lia. clear Eq2. - rewrite Nat.div_small_iff in H1 by lia. - rewrite Cmult_0_l. - destruct WF with (x := x) (y := y). lia. - reflexivity. - + rewrite andb_false_r. - assert (x / m <> 0)%nat by lia. clear Eq1. - rewrite Nat.div_small_iff in H1 by lia. - rewrite Cmult_0_l. - destruct WF with (x := x) (y := y). lia. - reflexivity. + intros Hi Hj. + rewrite !Nat.mod_1_r, !Nat.div_1_r, + !Nat.mod_small, !Nat.div_small by lia. + lca. Qed. +Lemma kron_2_0_mid_comm {n m} (A : Matrix n (2 ^ 0)) (B : Matrix (2 ^ 0) m) + (HA : WF_Matrix A) (HB : WF_Matrix B) : + A ⊗ B = B ⊗ A. +Proof. + now apply kron_1_1_mid_comm. +Qed. + Theorem transpose_involutive : forall (m n : nat) (A : Matrix m n), (A⊤)⊤ = A. Proof. reflexivity. Qed. Theorem adjoint_involutive : forall (m n : nat) (A : Matrix m n), A†† = A. Proof. intros. lma. Qed. +Lemma transpose_matrices : forall {n m} (A B : Matrix n m), + A ⊤ = B ⊤ -> A = B. +Proof. + intros. + rewrite <- transpose_involutive. + rewrite <- H. + rewrite transpose_involutive. + easy. +Qed. + +Lemma adjoint_matrices : forall {n m} (A B : Matrix n m), + A † = B † -> A = B. +Proof. + intros. + rewrite <- adjoint_involutive. + rewrite <- H. + rewrite adjoint_involutive. + easy. +Qed. + Lemma id_transpose_eq : forall n, (I n)⊤ = (I n). Proof. intros n. unfold transpose, I. @@ -1116,6 +1507,16 @@ Proof. intros. apply H. Qed. +Lemma Mscale_inv : forall {n m} (A B : Matrix n m) c, + c <> C0 -> c .* A = B <-> A = (/ c) .* B. +Proof. + intros. + split; intro H0; [rewrite <- H0 | rewrite H0]; + rewrite Mscale_assoc. + - rewrite Cinv_l; [ lma | assumption]. + - rewrite Cinv_r; [ lma | assumption]. +Qed. + Lemma Mscale_plus_distr_l : forall (m n : nat) (x y : C) (A : Matrix m n), (x + y) .* A = x .* A .+ y .* A. Proof. @@ -1211,6 +1612,12 @@ Lemma kron_transpose : forall (m n o p : nat) (A : Matrix m n) (B : Matrix o p ) (A ⊗ B)⊤ = A⊤ ⊗ B⊤. Proof. reflexivity. Qed. +Lemma kron_transpose' [m n o p] (A : Matrix m n) (B : Matrix o p) : + forall mo' mp', + @Matrix.transpose mo' mp' (A ⊗ B) = + (@Matrix.transpose m n A) ⊗ (@Matrix.transpose o p B). +Proof. reflexivity. Qed. + Lemma Mplus_adjoint : forall (m n : nat) (A : Matrix m n) (B : Matrix m n), (A .+ B)† = A† .+ B†. Proof. @@ -1236,6 +1643,103 @@ Proof. intros; lca. Qed. +Lemma direct_sum_adjoint : forall {m n o p : nat} + (A : Matrix m n) (B : Matrix o p), + (A .⊕ B) † = A † .⊕ B †. +Proof. + intros m n o p A B. + prep_matrix_equality. + unfold adjoint, direct_sum. + bdestructΩ'. +Qed. + +Lemma direct_sum_Mmult {m n o p q r} (A : Matrix m n) (B : Matrix n o) + (C : Matrix p q) (D : Matrix q r) : WF_Matrix A -> WF_Matrix B -> + WF_Matrix C -> WF_Matrix D -> + (A × B) .⊕ (C × D) = (A .⊕ C) × (B .⊕ D). +Proof. + intros HA HB HC HD. + assert (HAB : WF_Matrix (A × B)) by auto_wf. + assert (HCD : WF_Matrix (C × D)) by auto_wf. + prep_matrix_equivalence. + intros i j Hi Hj. + unfold direct_sum. + bdestruct (i if k if k - (x - y * z) mod z = x mod z. -Proof. - intros. bdestruct (z =? 0). subst. simpl. lia. - specialize (Nat.sub_add (y * z) x H) as G. - rewrite Nat.add_comm in G. - remember (x - (y * z)) as r. - rewrite <- G. rewrite <- Nat.add_mod_idemp_l by easy. rewrite Nat.mod_mul by easy. - easy. -Qed. - -Lemma mod_product : forall x y z, y <> 0 -> x mod (y * z) mod z = x mod z. -Proof. - intros x y z H. bdestruct (z =? 0). subst. - simpl. try rewrite Nat.mul_0_r. reflexivity. - pattern x at 2. rewrite Nat.mod_eq with (b := y * z) by nia. - replace (y * z * (x / (y * z))) with (y * (x / (y * z)) * z) by lia. - rewrite sub_mul_mod. easy. - replace (y * (x / (y * z)) * z) with (y * z * (x / (y * z))) by lia. - apply Nat.mul_div_le. nia. -Qed. - Lemma kron_assoc_mat_equiv : forall {m n p q r s : nat} (A : Matrix m n) (B : Matrix p q) (C : Matrix r s), (A ⊗ B ⊗ C) == A ⊗ (B ⊗ C). Proof. intros. intros i j Hi Hj. - remember (A ⊗ B ⊗ C) as LHS. - unfold kron. - rewrite (Nat.mul_comm p r) at 1 2. - rewrite (Nat.mul_comm q s) at 1 2. - assert (m * p * r <> 0) by lia. - assert (n * q * s <> 0) by lia. - apply Nat.neq_mul_0 in H as [Hmp Hr]. - apply Nat.neq_mul_0 in Hmp as [Hm Hp]. - apply Nat.neq_mul_0 in H0 as [Hnq Hs]. - apply Nat.neq_mul_0 in Hnq as [Hn Hq]. - rewrite <- 2 Nat.div_div by assumption. + unfold kron. + rewrite 2 mod_product. + rewrite (Nat.mul_comm p r). + rewrite (Nat.mul_comm q s). + rewrite <- 2 Nat.Div0.div_div. rewrite <- 2 div_mod. - rewrite 2 mod_product by assumption. rewrite Cmult_assoc. - subst. reflexivity. Qed. @@ -1363,8 +1818,7 @@ Lemma kron_assoc : forall {m n p q r s : nat} (A ⊗ B ⊗ C) = A ⊗ (B ⊗ C). Proof. intros. - apply mat_equiv_eq; auto with wf_db. - apply WF_kron; auto with wf_db; lia. + apply mat_equiv_eq; auto with wf_db zarith. apply kron_assoc_mat_equiv. Qed. @@ -1375,16 +1829,13 @@ Proof. unfold kron, Mmult. prep_matrix_equality. destruct q. - + simpl. - rewrite Nat.mul_0_r. - simpl. + + rewrite Nat.mul_0_r. rewrite Cmult_0_r. reflexivity. - + rewrite (@big_sum_product Complex.C _ _ _ C_is_ring). + + rewrite (@big_sum_product Complex.C _ _ _ C_is_ring) by easy. apply big_sum_eq. apply functional_extensionality. intros; lca. - lia. Qed. (* Arguments kron_mixed_product [m n o p q r]. *) @@ -1398,6 +1849,33 @@ Lemma kron_mixed_product' : forall (m n n' o p q q' r mp nq or: nat) (@kron m o p r (@Mmult m n o A C) (@Mmult p q r B D)). Proof. intros. subst. apply kron_mixed_product. Qed. +Lemma kron_id_dist_r : forall {n m o} p (A : Matrix n m) (B : Matrix m o), + WF_Matrix A -> WF_Matrix B -> (A × B) ⊗ (I p) = (A ⊗ (I p)) × (B ⊗ (I p)). +Proof. + intros. + now rewrite kron_mixed_product, Mmult_1_r by auto with wf_db. +Qed. + +Lemma kron_id_dist_l : forall {n m o} p (A : Matrix n m) (B : Matrix m o), + WF_Matrix A -> WF_Matrix B -> (I p) ⊗ (A × B) = ((I p) ⊗ A) × ((I p) ⊗ B). +Proof. + intros. + now rewrite kron_mixed_product, Mmult_1_r by auto with wf_db. +Qed. + +Lemma kron_split_diag {n m p q} (A : Matrix n m) (B : Matrix p q) + (HA : WF_Matrix A) (HB : WF_Matrix B) : + A ⊗ B = (A ⊗ I p) × (I m ⊗ B). +Proof. + now rewrite kron_mixed_product, Mmult_1_l, Mmult_1_r. +Qed. + +Lemma kron_split_antidiag {n m p q} (A : Matrix n m) (B : Matrix p q) + (HA : WF_Matrix A) (HB : WF_Matrix B) : + A ⊗ B = (I n ⊗ B) × (A ⊗ I q). +Proof. + now rewrite kron_mixed_product, Mmult_1_l, Mmult_1_r. +Qed. Lemma direct_sum_assoc : forall {m n p q r s : nat} (A : Matrix m n) (B : Matrix p q) (C : Matrix r s), @@ -1442,6 +1920,18 @@ Proof. induction l1. all : assert (H' := H (S i)); simpl in H'; easy. Qed. +Lemma kron_n_1 {n m} (A : Matrix n m) (HA : WF_Matrix A) : + 1 ⨂ A = A. +Proof. + now apply kron_1_l. +Qed. + +Lemma kron_n_S {n m} (A : Matrix n m) k : + (S k) ⨂ A = (k ⨂ A) ⊗ A. +Proof. + easy. +Qed. + Lemma kron_n_assoc : forall n {m1 m2} (A : Matrix m1 m2), WF_Matrix A -> (S n) ⨂ A = A ⊗ (n ⨂ A). Proof. @@ -1495,14 +1985,10 @@ Lemma kron_n_mult : forall {m1 m2 m3} n (A : Matrix m1 m2) (B : Matrix m2 m3), Proof. intros. induction n; simpl. - rewrite Mmult_1_l. reflexivity. - apply WF_I. - replace (m1 * m1 ^ n) with (m1 ^ n * m1) by apply Nat.mul_comm. - replace (m2 * m2 ^ n) with (m2 ^ n * m2) by apply Nat.mul_comm. - replace (m3 * m3 ^ n) with (m3 ^ n * m3) by apply Nat.mul_comm. - rewrite kron_mixed_product. - rewrite IHn. - reflexivity. + - apply Mmult_1_l, WF_I. + - rewrite <- IHn. + rewrite <- kron_mixed_product. + f_equal; apply Nat.mul_comm. Qed. Lemma kron_n_I : forall n, n ⨂ I 2 = I (2 ^ n). @@ -1516,6 +2002,17 @@ Proof. lia. Qed. +Lemma kron_n_I_gen : forall n m, n ⨂ I m = I (m ^ n). +Proof. + intros. + induction n; simpl. + reflexivity. + rewrite IHn. + rewrite id_kron. + apply f_equal. + lia. +Qed. + Lemma Mmult_n_kron_distr_l : forall {m n} i (A : Square m) (B : Square n), i ⨉ (A ⊗ B) = (i ⨉ A) ⊗ (i ⨉ B). Proof. @@ -1632,13 +2129,19 @@ Proof. lma. Qed. +Lemma Msum_transpose : forall n m p f, + (big_sum (G:=Matrix n m) f p) ⊤ = + big_sum (G:=Matrix n m) (fun i => (f i) ⊤) p. +Proof. + intros. + rewrite (big_sum_func_distr f transpose); easy. +Qed. + Lemma Msum_adjoint : forall {d1 d2} n (f : nat -> Matrix d1 d2), (big_sum f n)† = big_sum (fun i => (f i)†) n. Proof. intros. - induction n; simpl. - lma. - rewrite Mplus_adjoint, IHn. + rewrite (big_sum_func_distr f adjoint) by apply Mplus_adjoint. reflexivity. Qed. @@ -1646,23 +2149,30 @@ Lemma Msum_Csum : forall {d1 d2} n (f : nat -> Matrix d1 d2) i j, (big_sum f n) i j = big_sum (fun x => (f x) i j) n. Proof. intros. - induction n; simpl. - reflexivity. - unfold Mplus. - rewrite IHn. + rewrite (big_sum_func_distr f (fun g => g i j)) by easy. reflexivity. Qed. Lemma Msum_plus : forall n {d1 d2} (f g : nat -> Matrix d1 d2), big_sum (fun x => f x .+ g x) n = big_sum f n .+ big_sum g n. Proof. - clear. intros. induction n; simpl. lma. rewrite IHn. lma. Qed. +Lemma Mmult_vec_comm {n} (v u : Vector n) : WF_Matrix u -> WF_Matrix v -> + v ⊤%M × u = u ⊤%M × v. +Proof. + intros Hu Hv. + prep_matrix_equivalence. + by_cell. + apply big_sum_eq_bounded. + intros k Hk. + unfold transpose. + lca. +Qed. (** * Tactics **) @@ -1720,6 +2230,66 @@ Proof. reflexivity. Qed. +Lemma vec_to_list2D_eq {n} (v : Vector n) : WF_Matrix v -> + v = (list2D_to_matrix [vec_to_list v]) ⊤. +Proof. + intros HWF. + pose proof (vec_to_list_length n v) as Hlen. + apply mat_equiv_eq. + - auto_wf. + - cbn. rewrite Hlen. + apply WF_transpose. + apply show_WF_list2D_to_matrix; cbn; rewrite ?Hlen; try easy. + now rewrite Nat.eqb_refl. + - intros i j Hi Hj. + replace j with 0 by lia. + cbn. + now rewrite nth_vec_to_list. +Qed. + +Lemma matrix_eq_list2D_to_matrix {m n} (A : Matrix m n) (HA : WF_Matrix A) : + A = @make_WF m n (list2D_to_matrix ( + map vec_to_list (map (B:=Vector n) + (fun k => fun i j => if j =? 0 then A k i else C0) (seq 0 m)))). +Proof. + prep_matrix_equivalence. + rewrite make_WF_equiv. + intros i j Hi Hj. + unfold list2D_to_matrix. + rewrite (map_nth_small Zero) by now rewrite map_length, seq_length. + rewrite (map_nth_small 0) by now rewrite seq_length. + rewrite nth_vec_to_list by easy. + now rewrite seq_nth by easy. +Qed. + +Lemma list2D_to_matrix_cons l li : + list2D_to_matrix (l :: li) = + list2D_to_matrix [l] .+ + (fun x y => if x =? 0 then C0 else list2D_to_matrix li (x - 1) y). +Proof. + prep_matrix_equality. + autounfold with U_db. + cbn. + destruct x; [lca|]. + replace (S x - 1)%nat with x by lia. + destruct x, y; cbn; now Csimpl. +Qed. + +Lemma Mscale_list2D_to_matrix {n m} c li : + @eq (Matrix n m) (@scale n m c (list2D_to_matrix li)) + (list2D_to_matrix (map (map (Cmult c)) li)). +Proof. + prep_matrix_equality. + autounfold with U_db. + unfold list2D_to_matrix. + change [] with (map (Cmult c) []). + rewrite map_nth. + replace C0 with (c * C0)%C by lca. + rewrite map_nth. + now Csimpl. +Qed. + + (** Restoring Matrix Dimensions *) @@ -1734,8 +2304,8 @@ Ltac is_nat_equality := Ltac unify_matrix_dims tac := try reflexivity; - repeat (apply f_equal_gen; try reflexivity; - try (is_nat_equality; tac)). + repeat (apply f_equal_gen; + try reflexivity; try (is_nat_equality; tac)). Ltac restore_dims_rec A := match A with @@ -1778,6 +2348,12 @@ Ltac restore_dims_rec A := match type of A' with | Matrix ?m' ?n' => constr:(@eq (Matrix m' n') A' B') end + | mat_equiv ?A ?B => + let A' := restore_dims_rec A in + let B' := restore_dims_rec B in + match type of A' with + | Matrix ?m' ?n' => constr:(@mat_equiv m' n' A' B') + end | ?A × ?B => let A' := restore_dims_rec A in let B' := restore_dims_rec B in match type of A' with @@ -1837,19 +2413,96 @@ Ltac restore_dims_rec A := | ?A => A end. -Ltac restore_dims tac := +Ltac restore_dims_using tac := match goal with | |- ?A => let A' := restore_dims_rec A in replace A with A' by unify_matrix_dims tac end. -Tactic Notation "restore_dims" tactic(tac) := restore_dims tac. +Ltac restore_dims_by_exact tac := + match goal with + | |- ?A => let A' := restore_dims_rec A in + replace A with A' by tac + end. + +Ltac restore_dims_tac := + (* Can redefine with: + Ltac restore_dims_tac ::= (tactic). + to extend functionality. *) + (repeat rewrite Nat.pow_1_l; try ring; + unify_pows_two; simpl; lia). -Tactic Notation "restore_dims" := restore_dims (repeat rewrite Nat.pow_1_l; try ring; unify_pows_two; simpl; lia). +Ltac restore_dims := + restore_dims_using restore_dims_tac. +Tactic Notation "restore_dims" "by" tactic(tac) := + restore_dims_using tac. + +Tactic Notation "restore_dims" "in" hyp(H) "by" tactic3(tac) := + match type of H with + | ?A => let A' := restore_dims_rec A in + replace A with A' in H by unify_matrix_dims tac + end. + +Tactic Notation "restore_dims" "in" hyp(H) := + restore_dims in H by restore_dims_tac. + +Tactic Notation "restore_dims" "in" "*" "|-" "by" tactic3(tac) := + multimatch goal with + | H : _ |- _ => try restore_dims in H by tac + end. + +Tactic Notation "restore_dims" "in" "*" "|-" := + restore_dims in * |- by restore_dims_tac. + +Tactic Notation "restore_dims" "in" "*" "by" tactic3(tac) := + restore_dims in * |- by tac; restore_dims by tac. + +Tactic Notation "restore_dims" "in" "*" := + restore_dims in * by restore_dims_tac. (* Proofs depending on restore_dims *) +Lemma kron_n_assoc_mat_equiv : + forall n {m1 m2} (A : Matrix m1 m2), (S n) ⨂ A ≡ A ⊗ (n ⨂ A). +Proof. + intros. induction n. + - simpl. + rewrite kron_1_r. + restore_dims. + now rewrite kron_1_l_mat_equiv. + - cbn [kron_n] in *. + restore_dims in *. + rewrite IHn at 1. + restore_dims. + now rewrite kron_assoc_mat_equiv. +Qed. + +Lemma kron_n_m_split_mat_equiv {o p} : forall n m (A : Matrix o p), + (n + m) ⨂ A ≡ n ⨂ A ⊗ m ⨂ A. +Proof. + induction n. + - simpl. + intros. + symmetry. + restore_dims. + now rewrite kron_1_l_mat_equiv. + - intros. + simpl. + restore_dims. + rewrite IHn. + restore_dims by (rewrite ?Nat.pow_add_r; lia). + rewrite kron_assoc_mat_equiv. + symmetry. + restore_dims. + rewrite kron_assoc_mat_equiv. + pose proof (kron_n_assoc_mat_equiv m A) as H. + symmetry in H. + restore_dims in *. + rewrite H. + simpl. + now restore_dims. +Qed. Lemma kron_n_m_split {o p} : forall n m (A : Matrix o p), WF_Matrix A -> (n + m) ⨂ A = n ⨂ A ⊗ m ⨂ A. @@ -1928,7 +2581,6 @@ Ltac distribute_adjoint := (** Tactics for solving computational matrix equalities **) - (* Construct matrices full of evars *) Ltac mk_evar t T := match goal with _ => evar (t : T) end. @@ -2068,9 +2720,12 @@ Ltac crunch_matrix := Ltac compound M := match M with - | ?A × ?B => idtac - | ?A .+ ?B => idtac - | ?A † => compound A + | ?A × ?B => idtac + | ?A .+ ?B => idtac + | ?A .⊕ ?B => idtac + | ?A † => compound A + | ?A ⊤ => compound A + | _ .* ?A => compound A end. (* Reduce inner matrices first *) @@ -2113,6 +2768,103 @@ Ltac solve_matrix := assoc_least; (* try to solve complex equalities *) autorewrite with C_db; try lca. +Ltac compute_matrix_getval M := + let lem := constr:(matrix_eq_list2D_to_matrix M + ltac:(auto 100 using WF_Matrix_dim_change with wf_db)) in + lazymatch type of lem with + | _ = @make_WF ?n ?m (list2D_to_matrix ?l) => + let l' := fresh "l'" in let Hl' := fresh "Hl'" in + let _ := match goal with |- _ => + set (l' := l); + autounfold with U_db in l'; + cbn in l'; unfold Cdiv in l'; + let lval := eval unfold l' in l' in + pose proof (lem : _ = @make_WF n m (list2D_to_matrix lval)) as Hl'; + Csimpl_in Hl'; + rewrite Hl' + end in + lazymatch type of Hl' with + | _ = ?B => + let _ := match goal with |- _ => clear l' Hl' end in + constr:(B) + end + end. + +Ltac compute_matrix M := + let rec comp_mat_val M := + match M with + | @Mplus ?n ?m ?A .+ ?B => + let A' := match goal with + | |- _ => let _ := match goal with |- _ => compound A end in + let r := comp_mat_val A in constr:(r) + | |- _ => constr:(A) + end in + let B' := match goal with + | |- _ => let _ := match goal with |- _ => compound B end in + let r := comp_mat_val B in constr:(r) + | |- _ => constr:(B) + end in + let r := compute_matrix_getval (@Mplus n m A' B') in + constr:(r) + | @kron ?a ?b ?c ?d ?A ?B => + let A' := match goal with + | |- _ => let _ := match goal with |- _ => compound A end in + let r := comp_mat_val A in constr:(r) + | |- _ => constr:(A) + end in + let B' := match goal with + | |- _ => let _ := match goal with |- _ => compound B end in + let r := comp_mat_val B in constr:(r) + | |- _ => constr:(B) + end in + let r := compute_matrix_getval (@kron a b c d A' B') in + constr:(r) + | @Mmult ?a ?b ?c ?A ?B => + let A' := match goal with + | |- _ => let _ := match goal with |- _ => compound A end in + let r := comp_mat_val A in constr:(r) + | |- _ => constr:(A) + end in + let B' := match goal with + | |- _ => let _ := match goal with |- _ => compound B end in + let r := comp_mat_val B in constr:(r) + | |- _ => constr:(B) + end in + let r := compute_matrix_getval (@Mmult a b c A' B') in + constr:(r) + | @scale ?a ?b ?A => + let _ := match goal with |- _ => compound A end in + let A' := comp_mat_val A in + let r := compute_matrix_getval (@scale a b A') in + constr:(r) + | ?A => + let r := compute_matrix_getval A in + constr:(r) + end + in let _ := comp_mat_val M in idtac. + +Ltac solve_matrix_fast_with_tacs pretac posttac := + prep_matrix_equivalence; pretac; by_cell_no_intros; posttac. + +Tactic Notation "solve_matrix_fast_with" tactic0(pretac) tactic(posttac) := + solve_matrix_fast_with_tacs pretac posttac. + +Ltac solve_matrix_fast := + solve_matrix_fast_with idtac lca. + + +Create HintDb scalar_move_db. + +#[export] Hint Rewrite + Mscale_kron_dist_l + Mscale_kron_dist_r + Mscale_mult_dist_l + Mscale_mult_dist_r + Mscale_assoc : scalar_move_db. + +#[export] Hint Rewrite <- Mscale_plus_distr_l : scalar_move_db. +#[export] Hint Rewrite <- Mscale_plus_distr_r : scalar_move_db. + (** Gridify **) diff --git a/Measurement.v b/Measurement.v index b92d81a..9aa6b74 100644 --- a/Measurement.v +++ b/Measurement.v @@ -1,4 +1,5 @@ Require Import VectorStates. +Require Import Modulus. (** This file contains predicates for describing the outcomes of measurement. *) @@ -36,18 +37,18 @@ Lemma rewrite_I_as_sum : forall m n, I m = big_sum (fun i => (basis_vector n i) × (basis_vector n i)†) m. Proof. intros. - induction m. - simpl. - unfold I. - prep_matrix_equality. - bdestruct_all; reflexivity. + induction m; [simpl; lma'|]. simpl. rewrite <- IHm by lia. + change (Matrix (S m) (S m)) with (Matrix n n). + apply mat_equiv_eq; [auto_wf| + apply WF_Matrix_dim_change, WF_plus; auto with wf_db zarith |]; + [unfold I; intros ? ? []; Modulus.bdestructΩ'..|]. + intros i j Hi Hj. unfold basis_vector. - solve_matrix. - bdestruct_all; simpl; try lca. - all: destruct m; simpl; try lca. - all: bdestruct_all; lca. + autounfold with U_db. + simpl. + bdestruct_all; lca. Qed. Lemma prob_partial_meas_alt : @@ -59,7 +60,7 @@ Proof. unfold prob_partial_meas. rewrite norm_squared. unfold inner_product, Mmult, adjoint. - rewrite (@big_sum_func_distr C R _ C_is_group _ R_is_group), Nat.mul_1_l. + rewrite Re_big_sum, Nat.mul_1_l. apply big_sum_eq_bounded; intros. unfold probability_of_outcome. assert (H' : forall c, ((Cmod c)^2)%R = fst (c^* * c)). @@ -69,20 +70,18 @@ Proof. simpl; lra. } rewrite H'. apply f_equal. - assert (H'' : forall a b, a = b -> a^* * a = b^* * b). { intros; subst; easy. } - apply H''. + apply (f_equal (fun x => x^* * x)). unfold inner_product, Mmult. apply big_sum_eq_bounded; intros. apply f_equal_gen; auto. apply f_equal. unfold kron, adjoint. rewrite Cconj_mult_distr. - rewrite Nat.div_0_l, Nat.mod_0_l, (Nat.div_small x (2^n)), (Nat.mod_small x); try nia. + rewrite Nat.Div0.div_0_l, Nat.Div0.mod_0_l, + (Nat.div_small x (2^n)), (Nat.mod_small x) by nia. apply f_equal_gen; auto. unfold basis_vector, I. bdestruct_all; try lia; simpl; try lca. - intros. - destruct a; destruct b; easy. Qed. Lemma partial_meas_tensor : diff --git a/Modulus.v b/Modulus.v index 95a3635..b8f63f5 100644 --- a/Modulus.v +++ b/Modulus.v @@ -1,5 +1,1155 @@ +(** Some lemmas about nats, especially about div and mod *) + Require Import Prelim. +(** * Automation extending that in Prelim *) + +Ltac bdestruct_one := + let fail_if_iffy H := + match H with + | context[if _ then _ else _] => fail 1 + | _ => idtac + end + in + match goal with + | |- context [ ?a fail_if_iffy a; fail_if_iffy b; bdestruct (a fail_if_iffy a; fail_if_iffy b; bdestruct (a <=? b) + | |- context [ ?a =? ?b ] => fail_if_iffy a; fail_if_iffy b; bdestruct (a =? b) + | |- context[if ?b then _ else _] + => fail_if_iffy b; destruct b eqn:? + end. + +Ltac bdestructΩ' := + let tryeasylia := try easy; try lia in + repeat (bdestruct_one; subst; tryeasylia); + tryeasylia. + +Ltac replace_bool_lia b0 b1 := + first [ + replace b0 with b1 by (bdestruct b0; lia || (destruct b1 eqn:?; lia)) | + replace b0 with b1 by (bdestruct b1; lia || (destruct b0 eqn:?; lia)) | + replace b0 with b1 by (bdestruct b0; bdestruct b1; lia) + ]. + +Ltac simpl_bools := + repeat (cbn [andb orb negb xorb]; + rewrite ?andb_true_r, ?andb_false_r, ?orb_true_r, ?orb_false_r). + +Ltac simplify_bools_lia_one_free := + let act_T b := ((replace_bool_lia b true || replace_bool_lia b false); simpl) in + let act_F b := ((replace_bool_lia b false || replace_bool_lia b true); simpl) in + match goal with + | |- context[?b && _] => act_F b; rewrite ?andb_true_l, ?andb_false_l + | |- context[_ && ?b] => act_F b; rewrite ?andb_true_r, ?andb_false_r + | |- context[?b || _] => act_T b; rewrite ?orb_true_l, ?orb_false_l + | |- context[_ || ?b] => act_T b; rewrite ?orb_true_r, ?orb_false_r + | |- context[negb ?b] => act_T b; simpl negb + | |- context[if ?b then _ else _] => act_T b + end; simpl_bools. + +Ltac simplify_bools_lia_one_kernel := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + let fail_if_compound H := + fail_if_iffy H; + match H with + | context [ ?a && ?b ] => fail 1 + | context [ ?a || ?b ] => fail 1 + | _ => idtac + end + in + let act_T b := (fail_if_compound b; + (replace_bool_lia b true || replace_bool_lia b false); simpl) in + let act_F b := (fail_if_compound b; + (replace_bool_lia b false || replace_bool_lia b true); simpl) in + match goal with + | |- context[?b && _] => act_F b; rewrite ?andb_true_l, ?andb_false_l + | |- context[_ && ?b] => act_F b; rewrite ?andb_true_r, ?andb_false_r + | |- context[?b || _] => act_T b; rewrite ?orb_true_l, ?orb_false_l + | |- context[_ || ?b] => act_T b; rewrite ?orb_true_r, ?orb_false_r + | |- context[negb ?b] => act_T b; simpl negb + | |- context[if ?b then _ else _] => act_T b + end; simpl_bools. + +Ltac simplify_bools_lia_many_kernel := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + let fail_if_compound H := + fail_if_iffy H; + match H with + | context [ ?a && ?b ] => fail 1 + | context [ ?a || ?b ] => fail 1 + | _ => idtac + end + in + let act_T b := (fail_if_compound b; + (replace_bool_lia b true || replace_bool_lia b false); simpl) in + let act_F b := (fail_if_compound b; + (replace_bool_lia b false || replace_bool_lia b true); simpl) in + multimatch goal with + | |- context[?b && _] => act_F b; rewrite ?andb_true_l, ?andb_false_l + | |- context[_ && ?b] => act_F b; rewrite ?andb_true_r, ?andb_false_r + | |- context[?b || _] => act_T b; rewrite ?orb_true_l, ?orb_false_l + | |- context[_ || ?b] => act_T b; rewrite ?orb_true_r, ?orb_false_r + | |- context[negb ?b] => act_T b; simpl negb + | |- context[if ?b then _ else _] => act_T b + end; simpl_bools. + +Ltac simplify_bools_lia_one := + simplify_bools_lia_one_kernel || simplify_bools_lia_one_free. + +Ltac simplify_bools_lia := + repeat simplify_bools_lia_one. + +Ltac bdestruct_one_old := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + match goal with + | |- context [ ?a + fail_if_iffy a; fail_if_iffy b; bdestruct (a + fail_if_iffy a; fail_if_iffy b; bdestruct (a <=? b) + | |- context [ ?a =? ?b ] => + fail_if_iffy a; fail_if_iffy b; bdestruct (a =? b) + | |- context [ if ?b then _ else _ ] => fail_if_iffy b; destruct b eqn:? + end. + +Ltac bdestruct_one_new := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + let fail_if_booley H := + fail_if_iffy H; + match H with + | context [ ?a fail 1 + | context [ ?a <=? ?b ] => fail 1 + | context [ ?a =? ?b ] => fail 1 + | context [ ?a && ?b ] => fail 1 + | context [ ?a || ?b ] => fail 1 + | context [ negb ?a ] => fail 1 + | context [ xorb ?a ?b ] => fail 1 + | _ => idtac + end + in + let rec destruct_kernel H := + match H with + | context [ if ?b then _ else _ ] => destruct_kernel b + | context [ ?a + tryif fail_if_booley a then + (tryif fail_if_booley b then bdestruct (a + tryif fail_if_booley a then + (tryif fail_if_booley b then bdestruct (a <=? b) + else destruct_kernel b) else (destruct_kernel a) + | context [ ?a =? ?b ] => + tryif fail_if_booley a then + (tryif fail_if_booley b then bdestruct (a =? b); try subst + else destruct_kernel b) else (destruct_kernel a) + | context [ ?a && ?b ] => + destruct_kernel a || destruct_kernel b + | context [ ?a || ?b ] => + destruct_kernel a || destruct_kernel b + | context [ xorb ?a ?b ] => + destruct_kernel a || destruct_kernel b + | context [ negb ?a ] => + destruct_kernel a + | _ => idtac + end + in + simpl_bools; + match goal with + | |- context [ ?a =? ?b ] => + fail_if_iffy a; fail_if_iffy b; bdestruct (a =? b); try subst + | |- context [ ?a + fail_if_iffy a; fail_if_iffy b; bdestruct (a + fail_if_iffy a; fail_if_iffy b; bdestruct (a <=? b) + | |- context [ if ?b then _ else _ ] => fail_if_iffy b; destruct b eqn:? + end; + simpl_bools. + +Ltac bdestruct_one' := bdestruct_one_new || bdestruct_one_old. + +Ltac bdestructΩ'_with tac := + tac; + repeat (bdestruct_one'; subst; simpl_bools; tac); + tac. + +(* Ltac bdestructΩ'simp := + bdestructΩ'_with ltac:(try easy + lca + lia). *) + + + + +Lemma pow2_nonzero n : 2 ^ n <> 0. +Proof. + apply Nat.pow_nonzero; lia. +Qed. + +Ltac show_term_nonzero term := + match term with + | 2 ^ ?a => exact (pow2_nonzero a) + | ?a ^ ?b => exact (Nat.pow_nonzero a b ltac:(show_term_nonzero a)) + | ?a * ?b => + (assert (a <> 0) by (show_term_nonzero a); + assert (b <> 0) by (show_term_nonzero b); + lia) + | ?a + ?b => + ((assert (a <> 0) by (show_term_nonzero a) || + assert (b <> 0) by (show_term_nonzero b)); + lia) + | _ => lia + | _ => nia + end. + +Ltac show_nonzero := + match goal with + | |- ?t <> 0 => show_term_nonzero t + | |- 0 <> ?t => symmetry; show_term_nonzero t + | |- 0 < ?t => assert (t <> 0) by (show_term_nonzero t); lia + | |- ?t > 0 => assert (t <> 0) by (show_term_nonzero t); lia + | _ => lia + end. + +Ltac get_div_by_pow_2 t pwr := + match t with + | 2 ^ pwr => constr:(1) + | 2 ^ pwr * ?a => constr:(a) + | ?a * 2 ^ pwr => constr:(a) + | ?a * ?b => let ra := get_div_by_pow_2 a pwr in constr:(ra * b) + | ?a * ?b => let rb := get_div_by_pow_2 b pwr in constr:(a * rb) + | 2 ^ (?a + ?b) => + let val := constr:(2 ^ a * 2 ^ b) in + get_div_by_pow_2 val pwr + | ?a + ?b => + let ra := get_div_by_pow_2 a pwr in + let rb := get_div_by_pow_2 b pwr in + constr:(ra + rb) + | ?a - 1 => + let ra := get_div_by_pow_2 a pwr in + constr:(ra - 1) + end. + +Lemma div_mul_l a b : a <> 0 -> + (a * b) / a = b. +Proof. + rewrite Nat.mul_comm; + apply Nat.div_mul. +Qed. + + +Ltac show_div_by_pow2_ge t pwr := + (* Shows t / 2 ^ pwr <= get_div_by_pwr t pwr *) + match t with + | 2 ^ pwr => (* constr:(1) *) + rewrite (Nat.div_same (2^pwr) (pow2_nonzero pwr)); + apply Nat.le_refl + | 2 ^ pwr * ?a => (* constr:(a) *) + rewrite (div_mul_l (2^pwr) a (pow2_nonzero pwr)); + apply Nat.le_refl + | ?a * 2 ^ pwr => (* constr:(a) *) + rewrite (Nat.div_mul a (2^pwr) (pow2_nonzero pwr)); + apply Nat.le_refl + | ?a * (?b * ?c) => + let rval := constr:(a * b * c) in + show_div_by_pow2_ge rval pwr + | ?a * ?b => (* b is not right, so... *) + let rval := constr:(b * a) in + show_div_by_pow2_ge rval pwr + | ?a + ?b => + let ra := get_div_by_pow_2 a pwr in + let rb := get_div_by_pow_2 b pwr in + constr:(ra + rb) + | ?a - 1 => + fail 1 "Case not supported" + | 2 ^ (?a + ?b) => + let val := constr:(2 ^ a * 2 ^ b) in + rewrite (Nat.pow_add_r 2 a b); + show_div_by_pow2_ge val pwr + + end. + + +Ltac get_div_by t val := + match t with + | val => constr:(1) + | val * ?a => constr:(a) + | ?a * val => constr:(a) + | ?a * ?b => let ra := get_div_by a val in constr:(ra * b) + | ?a * ?b => let rb := get_div_by b val in constr:(a * rb) + | 2 ^ (?a + ?b) => + let val' := constr:(2 ^ a * 2 ^ b) in + get_div_by val' val + | ?a + ?b => + let ra := get_div_by a val in + let rb := get_div_by b val in + constr:(ra + rb) + | ?a - 1 => + let ra := get_div_by a val in + constr:(ra - 1) + end. + +Ltac show_div_by_ge t val := + (* Shows t / val <= get_div_by t val *) + match t with + | val => (* constr:(1) *) + rewrite (Nat.div_same val ltac:(show_term_nonzero val)); + apply Nat.le_refl + | val * ?a => (* constr:(a) *) + rewrite (div_mul_l val a ltac:(show_term_nonzero val)); + apply Nat.le_refl + | ?a * val => (* constr:(a) *) + rewrite (Nat.div_mul a val ltac:(show_term_nonzero val)); + apply Nat.le_refl + | ?a * (?b * ?c) => + let rval := constr:(a * b * c) in + show_div_by_ge rval val + | ?a * ?b => (* b is not right, so... *) + let rval := constr:(b * a) in + show_div_by_ge rval val + | ?a + ?b => + let ra := get_div_by a val in + let rb := get_div_by b val in + constr:(ra + rb) + | ?a - 1 => + nia || + fail 1 "Case not supported" + end. + +Ltac get_strict_upper_bound term := + match term with + | ?k mod 0 => let r := get_strict_upper_bound k in constr:(r) + | ?k mod (2 ^ ?a) => constr:(Nat.pow 2 a) + | ?k mod (?a ^ ?b) => constr:(Nat.pow a b) + | ?k mod ?a => + let _ := match goal with |- _ => assert (H: a <> 0) by show_nonzero end in + constr:(a) + | ?k mod ?a => + let _ := match goal with |- _ => assert (H: a = 0) by lia end in + constr:(k + 1) + + | 2 ^ ?a * ?t => let r := get_strict_upper_bound t in + constr:(Nat.mul (Nat.pow 2 a) r) + | ?t * 2 ^ ?a => let r := get_strict_upper_bound t in + constr:(Nat.mul r (Nat.pow 2 a)) + | ?a ^ ?b => constr:(Nat.pow a b + 1) + + | ?a + ?b => + let ra := get_strict_upper_bound a in + let rb := get_strict_upper_bound b in + constr:(ra + rb + 1) + | ?a * ?b => + let ra := get_strict_upper_bound a in + let rb := get_strict_upper_bound b in + constr:(ra * rb + 1) + | ?a / (?b * (?c * ?d)) => let rval := constr:(a / (b * c * d)) in + let r := get_strict_upper_bound rval in constr:(r) + | ?a / (?b * ?c) => let rval := constr:(a / b / c) in + let r := get_strict_upper_bound rval in constr:(r) + | ?a / (2 ^ ?b) => + let ra := get_strict_upper_bound a in + let rr := get_div_by_pow_2 ra b in constr:(rr) + + | ?t => match goal with + | H : t < ?a |- _ => constr:(a) + | H : t <= ?a |- _ => constr:(a + 1) + | _ => constr:(t + 1) + end + end. + +Ltac get_upper_bound term := + match term with + | ?k mod 0 => let r := get_upper_bound k in constr:(r) + | ?k mod (2 ^ ?a) => constr:(Nat.sub (Nat.pow 2 a) 1) + | ?k mod (?a ^ ?b) => constr:(Nat.sub (Nat.pow a b) 1) + | ?k mod ?a => + let H := fresh in + let _ := match goal with |- _ => + assert (H: a <> 0) by show_nonzero; clear H end in + constr:(a - 1) + | ?k mod ?a => + let H := fresh in + let _ := match goal with |- _ => + assert (H: a = 0) by lia; clear H end in + let rk := get_upper_bound k in + constr:(rk) + + | 2 ^ ?a * ?t => let r := get_upper_bound t in + constr:(Nat.mul (Nat.pow 2 a) r) + | ?t * 2 ^ ?a => let r := get_upper_bound t in + constr:(Nat.mul r (Nat.pow 2 a)) + | ?a ^ ?b => constr:(Nat.pow a b) + + | ?a + ?b => + let ra := get_upper_bound a in + let rb := get_upper_bound b in + constr:(ra + rb) + | ?a * ?b => + let ra := get_upper_bound a in + let rb := get_upper_bound b in + constr:(ra * rb) + | ?a / (?b * (?c * ?d)) => let rval := constr:(a / (b * c * d)) in + let r := get_upper_bound rval in constr:(r) + | ?a / (?b * ?c) => let rval := constr:(a / b / c) in + let r := get_upper_bound rval in constr:(r) + | ?a / (2 ^ ?b) => + let ra := get_strict_upper_bound a in + let rr := get_div_by_pow_2 ra b in constr:(rr - 1) + + | ?a / ?b => + let ra := get_strict_upper_bound a in + let rr := get_div_by ra b in constr:(rr - 1) + + | ?t => match goal with + | H : t < ?a |- _ => constr:(a - 1) + | H : t <= ?a |- _ => constr:(a) + | _ => t + end + end. + +Lemma mul_ge_l_of_nonzero p q : q <> 0 -> + p <= p * q. +Proof. + nia. +Qed. + +Lemma mul_ge_r_of_nonzero p q : p <> 0 -> + q <= p * q. +Proof. + nia. +Qed. + +Ltac show_pow2_le := + rewrite ?Nat.pow_add_r, + ?Nat.mul_add_distr_r, ?Nat.mul_add_distr_l, + ?Nat.mul_sub_distr_r, ?Nat.mul_sub_distr_l, + ?Nat.mul_1_r, ?Nat.mul_1_l; + repeat match goal with + |- context [2 ^ ?a] => + tryif assert (2 ^ a <> 0) by assumption + then fail + else pose proof (pow2_nonzero a) + end; + nia || ( + repeat match goal with + | |- context [?p * ?q] => + tryif assert (p <> 0) by assumption + then + (tryif assert (q <> 0) by assumption + then fail + else assert (q <> 0) by nia) + else assert (p <> 0) by nia; + (tryif assert (q <> 0) by assumption + then idtac else assert (q <> 0) by nia) + end; + repeat match goal with + | |- context [?p * ?q] => + tryif assert (p <= p * q) by assumption + then + (tryif assert (q <= p * q) by assumption + then fail + else pose proof (mul_ge_r_of_nonzero p q ltac:(assumption))) + else pose proof (mul_ge_l_of_nonzero p q ltac:(assumption)); + (tryif assert (q <= p * q) by assumption + then idtac + else pose proof (mul_ge_r_of_nonzero p q ltac:(assumption))) + end; + nia). + + +Lemma lt_of_le_sub_1 a b : + b <> 0 -> a <= b - 1 -> a < b. +Proof. lia. Qed. + +Lemma le_sub_1_of_lt a b : + a < b -> a <= b - 1. +Proof. lia. Qed. + +(* FIXME: TODO: Remove in favor of Nat.Div0.div_le_mono when we upgrade past Coq ~8.16*) +Lemma div0_div_le_mono : forall a b c : nat, a <= b -> a / c <= b / c. +Proof. + intros a b []; [easy|]. + apply Nat.div_le_mono; easy. +Qed. + +Lemma div0_div_lt_upper_bound : forall a b c : nat, a < b * c -> + a / b < c. +Proof. + intros a b c H; apply Nat.div_lt_upper_bound; lia. +Qed. + +Lemma div0_div_div : forall a b c, a / b / c = a / (b * c). +Proof. + intros a [] []; [rewrite ?Nat.mul_0_r; easy..|]. + now apply Nat.div_div. +Qed. + +Lemma nat_mod_0_r : forall a, a mod 0 = a. +Proof. easy. Qed. + +Lemma div0_mod_0_l : forall a, 0 mod a = 0. +Proof. + intros []; [easy|]; + now apply Nat.mod_0_l. +Qed. + +Lemma div0_mod_add : forall a b c, (a + b * c) mod c = a mod c. +Proof. + intros a b []; [f_equal; lia|]; + now apply Nat.mod_add. +Qed. + +Lemma div0_mod_mul_r : forall a b c, + a mod (b * c) = a mod b + b * ((a / b) mod c). +Proof. + intros a [] []; rewrite ?Nat.mul_0_r, ?Nat.mul_0_l, + ?nat_mod_0_r; [lia..| pose proof (Nat.div_mod_eq a (S n)); lia |]. + now apply Nat.mod_mul_r. +Qed. + +Lemma div0_mod_mod : forall a n, (a mod n) mod n = a mod n. +Proof. + intros a []; [easy|]; now apply Nat.mod_mod. +Qed. + +Lemma div0_mod_mul : forall a b, (a * b) mod b = 0. +Proof. + intros a []; [cbn;lia|]; + now apply Nat.mod_mul. +Qed. + +Lemma div0_add_mod_idemp_l : forall a b n : nat, + (a mod n + b) mod n = (a + b) mod n. +Proof. + intros a b []; [easy|]; now apply Nat.add_mod_idemp_l. +Qed. + +Lemma div0_add_mod : forall a b n, + (a + b) mod n = (a mod n + b mod n) mod n. +Proof. + intros a b []; [easy|]; + now apply Nat.add_mod. +Qed. + +Lemma div0_mod_same : forall n, + n mod n = 0. +Proof. + intros []; [easy|]; now apply Nat.mod_same. +Qed. + +Lemma div0_div_0_l : forall n, 0 / n = 0. +Proof. intros []; easy. Qed. + +Notation "Nat.Div0.div_le_mono" := div0_div_le_mono. +Notation "Nat.Div0.div_lt_upper_bound" := div0_div_lt_upper_bound. +Notation "Nat.Div0.div_div" := div0_div_div. +Notation "Nat.mod_0_r" := nat_mod_0_r. +Notation "Nat.Div0.div_0_l" := div0_div_0_l. +Notation "Nat.Div0.mod_0_l" := div0_mod_0_l. +Notation "Nat.Div0.mod_add" := div0_mod_add. +Notation "Nat.Div0.mod_same" := div0_mod_same. +Notation "Nat.Div0.mod_mul_r" := div0_mod_mul_r. +Notation "Nat.Div0.mod_mod" := div0_mod_mod. +Notation "Nat.Div0.mod_mul" := div0_mod_mul. +Notation "Nat.Div0.add_mod" := div0_add_mod. +Notation "Nat.Div0.add_mod_idemp_l" := div0_add_mod_idemp_l. + + +Ltac show_le_upper_bound term := + lazymatch term with + | ?k mod 0 => + rewrite (Nat.mod_0_r k); + show_le_upper_bound k + | ?k mod (2 ^ ?a) => + exact (le_sub_1_of_lt (k mod (2^a)) (2^a) + (Nat.mod_upper_bound k (2^a) (pow2_nonzero a))) + | ?k mod (?a ^ ?b) => + exact (le_sub_1_of_lt (k mod (2^a)) (a^b) + (Nat.mod_upper_bound k (a^b) + (Nat.pow_nonzero a b ltac:(show_term_nonzero a)))) + | ?k mod ?a => + let H := fresh in + let _ := match goal with |- _ => + assert (H: a <> 0) by show_nonzero end in + exact (le_sub_1_of_lt _ _ (Nat.mod_upper_bound k a H)) + | ?k mod ?a => + let H := fresh in + let _ := match goal with |- _ => + assert (H: a = 0) by lia end in + rewrite H; + show_le_upper_bound k + + | 2 ^ ?a * ?t => let r := get_upper_bound t in + apply (Nat.mul_le_mono_l t _ (2^a)); + show_le_upper_bound t + | ?t * 2 ^ ?a => let r := get_upper_bound t in + apply (Nat.mul_le_mono_r t _ (2^a)); + show_le_upper_bound t + | ?a ^ ?b => + apply Nat.le_refl + + | ?a + ?b => + apply Nat.add_le_mono; + [ + (* match goal with |- ?G => idtac G "should be about" a end; *) + show_le_upper_bound a | + show_le_upper_bound b] + | ?a * ?b => + apply Nat.mul_le_mono; + [ + (* match goal with |- ?G => idtac G "should be about" a end; *) + show_le_upper_bound a | + show_le_upper_bound b] + | ?a / (?b * (?c * ?d)) => + let H := fresh in + pose proof (f_equal (fun x => a / x) (Nat.mul_assoc b c d) : + a / (b * (c * d)) = a / (b * c * d)) as H; + rewrite H; + clear H; + let rval := constr:(a / (b * c * d)) in + show_le_upper_bound rval + | ?a / (?b * ?c) => + let H := fresh in + pose proof (eq_sym (Nat.Div0.div_div a b c) : + a / (b * c) = a / b / c) as H; + rewrite H; + clear H; + let rval := constr:(a / b / c) in + show_le_upper_bound rval + | ?a / (2 ^ ?b) => + let ra := get_upper_bound a in + apply (Nat.le_trans (a / (2^b)) (ra / (2^b)) _); + [apply Nat.Div0.div_le_mono; + show_le_upper_bound a | + tryif show_div_by_pow2_ge ra b then idtac + else + match goal with + | |- (?val - 1) / 2 ^ ?pwr <= ?rhs - 1 => + apply le_sub_1_of_lt, Nat.Div0.div_lt_upper_bound; + tryif nia || show_pow2_le then idtac + else fail 20 "nia failed" "on (" val "- 1) / 2 ^" pwr "<=" rhs "- 1" + | |- ?G => + tryif nia then idtac else + fail 40 "show div failed for" a "/ (2^" b "), ra =" ra + "; full goal:" G + end] + | ?a / ?b => + let ra := get_upper_bound a in + apply (Nat.le_trans (a / b) (ra / b) _); + [apply Nat.Div0.div_le_mono; + show_le_upper_bound a | + tryif show_div_by_ge ra b then idtac + else + match goal with + | |- (?val - 1) / ?den <= ?rhs - 1 => + apply le_sub_1_of_lt, Nat.Div0.div_lt_upper_bound; + tryif nia || show_pow2_le then idtac + else fail 20 "nia failed" "on (" val "- 1) / " den "<=" rhs "- 1" + | |- ?G => + tryif nia then idtac else + fail 40 "show div failed for" a "/ (" b "), ra =" ra + "; full goal:" G + end] + | ?t => match goal with + | _ => nia + end + end. + +Create HintDb show_moddy_lt_db. + +Ltac show_moddy_lt := + try trivial with show_moddy_lt_db; + lazymatch goal with + | |- Nat.b2n ?b < ?a => + apply (Nat.le_lt_trans (Nat.b2n b) (2^1) a); + [destruct b; simpl; lia | show_pow2_le] + | |- ?a < ?b => + let r := get_upper_bound a in + apply (Nat.le_lt_trans a r b); + [show_le_upper_bound a | show_pow2_le] + | |- ?a <= ?b => (* Likely not to work *) + let r := get_upper_bound a in + apply (Nat.le_trans a r b); + [show_le_upper_bound a | show_pow2_le] + | |- ?a > ?b => + change (b < a); show_moddy_lt + | |- ?a >= ?b => + change (b <= a); show_moddy_lt + | |- (?a + apply (proj2 (Nat.ltb_lt a b)); + show_moddy_lt + | |- true = (?a + symmetry; + apply (proj2 (Nat.ltb_lt a b)); + show_moddy_lt + | |- (?a <=? ?b) = false => + apply (proj2 (Nat.leb_gt a b)); + show_moddy_lt + | |- false = (?a <=? ?b) => + symmetry; + apply (proj2 (Nat.leb_gt a b)); + show_moddy_lt + end. + +Ltac try_show_moddy_lt := + try trivial with show_moddy_lt_db; + lazymatch goal with + | |- Nat.b2n ?b < ?a => + apply (Nat.le_lt_trans (Nat.b2n b) (2^1) a); + [destruct b; simpl; lia | try show_pow2_le] + | |- ?a < ?b => + let r := get_upper_bound a in + apply (Nat.le_lt_trans a r b); + [try show_le_upper_bound a | try show_pow2_le] + | |- ?a <= ?b => (* Likely not to work *) + let r := get_upper_bound a in + apply (Nat.le_trans a r b); + [try show_le_upper_bound a | try show_pow2_le] + | |- ?a > ?b => + change (b < a); try_show_moddy_lt + | |- ?a >= ?b => + change (b <= a); try_show_moddy_lt + | |- (?a + apply (proj2 (Nat.ltb_lt a b)); + try_show_moddy_lt + | |- true = (?a + symmetry; + apply (proj2 (Nat.ltb_lt a b)); + try_show_moddy_lt + | |- (?a <=? ?b) = false => + apply (proj2 (Nat.leb_gt a b)); + try_show_moddy_lt + | |- false = (?a <=? ?b) => + symmetry; + apply (proj2 (Nat.leb_gt a b)); + try_show_moddy_lt + end. + +Ltac replace_bool_moddy_lia b0 b1 := + first + [ replace b0 with b1 + by (show_moddy_lt || bdestruct b0; show_moddy_lt + lia + || (destruct b1 eqn:?; lia)) + | replace b0 with b1 + by (bdestruct b1; lia || (destruct b0 eqn:?; lia)) + | replace b0 with b1 + by (bdestruct b0; bdestruct b1; lia) ]. + +Ltac simpl_bools_nosimpl := + repeat (rewrite ?andb_true_r, ?andb_false_r, ?orb_true_r, ?orb_false_r, + ?andb_true_l, ?andb_false_l, ?orb_true_l, ?orb_false_l). + +Ltac simplify_bools_moddy_lia_one_kernel := + let fail_if_iffy H := + match H with + | context [ if _ then _ else _ ] => fail 1 + | _ => idtac + end + in + let fail_if_compound H := + fail_if_iffy H; + match H with + | context [ ?a && ?b ] => fail 1 + | context [ ?a || ?b ] => fail 1 + | _ => idtac + end + in + let act_T b := (fail_if_compound b; + (replace_bool_moddy_lia b true + || replace_bool_moddy_lia b false); simpl) in + let act_F b := (fail_if_compound b; + (replace_bool_moddy_lia b false + || replace_bool_moddy_lia b true); simpl) in + match goal with + | |- context[?b && _] => act_F b; rewrite ?andb_true_l, ?andb_false_l + | |- context[_ && ?b] => act_F b; rewrite ?andb_true_r, ?andb_false_r + | |- context[?b || _] => act_T b; rewrite ?orb_true_l, ?orb_false_l + | |- context[_ || ?b] => act_T b; rewrite ?orb_true_r, ?orb_false_r + | |- context[negb ?b] => act_T b; cbn [negb] + | |- context[if ?b then _ else _] => act_T b + end; simpl_bools_nosimpl. + +(** * Some general nat facts *) + +Section nat_lemmas. + +Import Nat. + +Local Open Scope nat. + +Lemma add_sub' n m : m + n - m = n. +Proof. + lia. +Qed. + +Lemma add_leb_mono_l n m d : + (n + m <=? n + d) = (m <=? d). +Proof. + bdestructΩ'. +Qed. + +Lemma add_ltb_mono_l n m d : + (n + m m <= d. +Proof. lia. Qed. + +Lemma add_lt_cancel_l_iff n m d : + n + m < n + d <-> m < d. +Proof. lia. Qed. + +Lemma add_ge_cancel_l_iff n m d : + n + m >= n + d <-> m >= d. +Proof. lia. Qed. + +Lemma add_gt_cancel_l_iff n m d : + n + m > n + d <-> m > d. +Proof. lia. Qed. + +Lemma sub_lt_iff n m p (Hp : 0 <> p) : + n - m < p <-> n < m + p. +Proof. + split; lia. +Qed. + +Lemma sub_eq_iff {a b m} : b <= a -> + a - b = m <-> a = b + m. +Proof. + lia. +Qed. + +Lemma n_le_pow_2_n (n : nat) : n <= 2 ^ n. +Proof. + induction n; simpl; [lia|]. + pose proof (pow_positive 2 n). + lia. +Qed. + +Lemma div_mul_not_exact a b : b <> 0 -> + (a / b) * b = a - (a mod b). +Proof. + intros Hb. + rewrite (Nat.div_mod a b Hb) at 1 2. + rewrite Nat.add_sub. + rewrite (Nat.mul_comm b (a/b)), Nat.add_comm, Nat.div_add by easy. + rewrite Nat.div_small by (apply Nat.mod_upper_bound; easy). + easy. +Qed. + +Lemma diff_divs_lower_bound a b k n : + (a < n -> b < n -> a / k <> b / k -> k < n)%nat. +Proof. + intros Ha Hb Hne. + bdestructΩ (k + (x - y * z) mod z = x mod z. +Proof. + intros. + replace (x mod z) with ((x - y * z + y * z) mod z) by (f_equal; lia). + now rewrite Nat.Div0.mod_add. +Qed. + +Lemma mod_product : forall x y z, x mod (y * z) mod z = x mod z. +Proof. + intros. + rewrite Nat.mul_comm, Nat.Div0.mod_mul_r, Nat.mul_comm. + now rewrite Nat.Div0.mod_add, Nat.Div0.mod_mod. +Qed. + +Lemma mod_add_l a b c : (a * b + c) mod b = c mod b. +Proof. + rewrite Nat.add_comm. + apply Nat.Div0.mod_add. +Qed. + +Lemma div_eq a b : a / b = (a - a mod b) / b. +Proof. + rewrite (Nat.div_mod_eq a b) at 2. + rewrite Nat.add_sub. + bdestruct (b =? 0). + - now subst. + - now rewrite Nat.mul_comm, Nat.div_mul by easy. +Qed. + +Lemma sub_mod_le n m : m <= n -> + (n - m mod n) mod n = (n - m) mod n. +Proof. + intros Hm. + bdestruct (m =? n). + - subst. + now rewrite Nat.Div0.mod_same, Nat.sub_0_r, Nat.sub_diag, + Nat.Div0.mod_same, Nat.Div0.mod_0_l. + - now rewrite (Nat.mod_small m) by lia. +Qed. + +Lemma mod_mul_sub_le a b c : c <> 0 -> a <= b * c -> + (b * c - a) mod c = + c * Nat.b2n (¬ a mod c =? 0) - a mod c. +Proof. + intros Hc Ha. + bdestruct (a =? b * c). + - subst. + rewrite Nat.sub_diag, Nat.Div0.mod_mul, Nat.Div0.mod_0_l. + cbn; lia. + - rewrite (Nat.div_mod_eq a c) at 1. + assert (a < b * c) by lia. + assert (a / c < b) by (apply Nat.Div0.div_lt_upper_bound; lia). + assert (a mod c < c) by show_moddy_lt. + replace (b * c - (c * (a / c) + a mod c)) with + ((b - a / c - 1) * c + (c - a mod c)) by nia. + rewrite mod_add_l. + bdestruct (a mod c =? 0). + + replace -> (a mod c). + rewrite Nat.sub_0_r, Nat.Div0.mod_same. + cbn; lia. + + rewrite Nat.mod_small by lia. + cbn; lia. +Qed. + +Lemma div_sub a b c : c <> 0 -> + (b * c - a) / c = b - a / c - Nat.b2n (¬ a mod c =? 0). +Proof. + intros Hc. + bdestruct (a ; easy|]. + intros. + rewrite (Nat.div_mod_eq i a), (Nat.div_mod_eq j a). + lia. +Qed. + +Lemma eqb_comb_iff_div_mod_eqb a i x y (Hy : y < a) : + i =? x * a + y = + (i mod a =? y) && (i / a =? x). +Proof. + rewrite (eqb_iff_div_mod_eqb a). + rewrite mod_add_l, Nat.div_add_l, + (Nat.mod_small y), (Nat.div_small y) by lia. + now rewrite Nat.add_0_r. +Qed. + +Lemma eqb_div_mod_pow_2_iff a i j k l : + i mod 2 ^ a + 2 ^ a * j =? k mod 2 ^ a + 2 ^ a * l = + ((i mod 2 ^ a =? k mod 2 ^ a) && + (j =? l)). +Proof. + apply eq_iff_eq_true. + rewrite andb_true_iff, !Nat.eqb_eq. + split; try lia. + rewrite 2!(Nat.mul_comm (2^a)). + intros H. + generalize (f_equal (fun x => x mod 2^a) H). + rewrite 2!Nat.Div0.mod_add, !Nat.Div0.mod_mod. + intros; split; [easy|]. + generalize (f_equal (fun x => x / 2^a) H). + now rewrite 2!Nat.div_add, !Nat.div_small by + (try apply Nat.mod_upper_bound; try apply pow_nonzero; lia). +Qed. + +Lemma succ_even_lt_even a b : Nat.even a = true -> + Nat.even b = true -> + a < b -> S a < b. +Proof. + intros Ha Hb Hab. + enough (S a <> b) by lia. + intros Hf. + apply (f_equal Nat.even) in Hf. + rewrite Nat.even_succ in Hf. + rewrite <- Nat.negb_even in Hf. + rewrite Ha, Hb in Hf. + easy. +Qed. + +Lemma succ_odd_lt_odd a b : Nat.odd a = true -> + Nat.odd b = true -> + a < b -> S a < b. +Proof. + intros Ha Hb Hab. + enough (S a <> b) by lia. + intros Hf. + apply (f_equal Nat.even) in Hf. + rewrite Nat.even_succ in Hf. + rewrite <- Nat.negb_odd in Hf. + rewrite Ha, Hb in Hf. + easy. +Qed. + +Lemma even_add_same n : Nat.even (n + n) = true. +Proof. + now rewrite Nat.even_add, eqb_reflx. +Qed. + +Lemma even_succ_false n : + Nat.even (S n) = false <-> Nat.even n = true. +Proof. + rewrite Nat.even_succ, <- Nat.negb_even. + now destruct (Nat.even n). +Qed. + +Lemma even_succ_add_same n : Nat.even (S (n + n)) = false. +Proof. + now rewrite even_succ_false, even_add_same. +Qed. + +Lemma odd_succ_false n : + Nat.odd (S n) = false <-> Nat.odd n = true. +Proof. + rewrite Nat.odd_succ, <- Nat.negb_odd. + now destruct (Nat.odd n). +Qed. + +Lemma even_le_even_of_le_succ m n + (Hm : Nat.even m = true) (Hn : Nat.even n = true) : + (n <= S m -> n <= m)%nat. +Proof. + intros Hnm. + bdestructΩ (n =? S m). + replace -> n in Hn. + rewrite Nat.even_succ, <- Nat.negb_even in Hn. + now rewrite Hm in Hn. +Qed. + +Lemma even_eqb n : Nat.even n = (n mod 2 =? 0). +Proof. + rewrite (Nat.div_mod_eq n 2) at 1. + rewrite Nat.even_add, Nat.even_mul. + cbn [Nat.even orb]. + pose proof (Nat.mod_upper_bound n 2 ltac:(lia)). + now destruct ((ltac:(lia) : n mod 2 = O \/ n mod 2 = 1%nat)) as + [-> | ->]. +Qed. + +Lemma mod_2_succ n : (S n) mod 2 = 1 - (n mod 2). +Proof. + pose proof (Nat.mod_upper_bound (S n) 2 ltac:(lia)). + pose proof (Nat.mod_upper_bound n 2 ltac:(lia)). + enough (~ (S n mod 2 = 0) <-> n mod 2 = 0) by lia. + rewrite <- Nat.eqb_neq, <- Nat.eqb_eq. + rewrite <- 2!even_eqb. + apply even_succ_false. +Qed. + +Lemma double_add n m : n + m + (n + m) = n + n + (m + m). +Proof. + lia. +Qed. + +Lemma sub_leb_eq n m p : + n - m <=? p = (n <=? m + p). +Proof. + bdestructΩ'. +Qed. + +Lemma sub_ltb_eq_nonzero n m p : p <> 0 -> + n - m + n - m l < n -> k <> l -> + 2^k + 2^l < 2 ^ n. +Proof. + intros. + bdestruct (2^k + 2^l + n - S (n - S k) = k. +Proof. + lia. +Qed. + + Lemma div_mod_inj {a b} (c :nat) : c > 0 -> (a mod c) = (b mod c) /\ (a / c) = (b / c) -> a = b. Proof. @@ -12,63 +1162,401 @@ Proof. Qed. Lemma mod_add_n_r : forall m n, - (m + n) mod n = m mod n. + (m + n) mod n = m mod n. Proof. - intros m n. - replace (m + n)%nat with (m + 1 * n)%nat by lia. - destruct n. - - cbn; easy. - - rewrite Nat.mod_add; - lia. + intros m n. + replace (m + n)%nat with (m + 1 * n)%nat by lia. + destruct n. + - cbn; easy. + - now rewrite Nat.Div0.mod_add. Qed. Lemma mod_eq_sub : forall m n, - m mod n = (m - n * (m / n))%nat. + m mod n = (m - n * (m / n))%nat. Proof. - intros m n. - destruct n. - - cbn; lia. - - assert (H: (S n <> 0)%nat) by easy. - pose proof (Nat.div_mod m (S n) H) as Heq. - lia. + intros m n. + destruct n. + - cbn; lia. + - pose proof (Nat.div_mod m (S n)). + lia. Qed. Lemma mod_of_scale : forall m n q, - (n * q <= m < n * S q)%nat -> m mod n = (m - q * n)%nat. + (n * q <= m < n * S q)%nat -> m mod n = (m - q * n)%nat. Proof. - intros m n q [Hmq HmSq]. - rewrite mod_eq_sub. - replace (m/n)%nat with q; [lia|]. - apply Nat.le_antisymm. - - apply Nat.div_le_lower_bound; lia. - - epose proof (Nat.div_lt_upper_bound m n (S q) _ _). - lia. - Unshelve. - all: lia. + intros m n q [Hmq HmSq]. + rewrite mod_eq_sub. + replace (m/n)%nat with q; [lia|]. + apply Nat.le_antisymm. + - apply Nat.div_le_lower_bound; lia. + - pose proof (Nat.Div0.div_lt_upper_bound m n (S q)). + lia. Qed. Lemma mod_n_to_2n : forall m n, - (n <= m < 2 * n)%nat -> m mod n = (m - n)%nat. + (n <= m < 2 * n)%nat -> m mod n = (m - n)%nat. Proof. - intros. - epose proof (mod_of_scale m n 1 _). - lia. - Unshelve. - lia. + intros. + pose proof (mod_of_scale m n 1). + lia. Qed. Lemma mod_n_to_n_plus_n : forall m n, - (n <= m < n + n)%nat -> m mod n = (m - n)%nat. + (n <= m < n + n)%nat -> m mod n = (m - n)%nat. Proof. - intros. - apply mod_n_to_2n; lia. + intros. + apply mod_n_to_2n; lia. Qed. +(** Lemmas about Nat.testbit *) + +Lemma testbit_add_pow2_small (i j s n : nat) (Hs : s < n) : + Nat.testbit (i + 2^n * j) s = Nat.testbit i s. +Proof. + rewrite 2!Nat.testbit_eqb. + replace n with (s + (n - s)) by lia. + rewrite Nat.pow_add_r, <- Nat.mul_assoc, Nat.mul_comm, Nat.div_add by + (apply Nat.pow_nonzero; lia). + destruct (n - s) eqn:e; [lia|]. + cbn [Nat.pow]. + rewrite <- Nat.mul_assoc, Nat.mul_comm, Nat.Div0.mod_add by lia. + easy. +Qed. + +Lemma testbit_add_pow2_large (i j s n : nat) (Hs : n <= s) (Hi : i < 2^n) : + Nat.testbit (i + 2^n * j) s = Nat.testbit j (s - n). +Proof. + replace s with (s-n + n) at 1 by lia. + generalize (s - n) as d. + intros d. + rewrite 2!Nat.testbit_eqb. + do 2 f_equal. + rewrite Nat.pow_add_r, (Nat.mul_comm _ (2^_)), Nat.mul_comm, + <- Nat.Div0.div_div, Nat.div_add by + (apply Nat.pow_nonzero; lia). + rewrite (Nat.div_small i) by easy. + easy. +Qed. + +Lemma testbit_add_pow2_split i j n (Hi : i < 2^n) : + forall s, + Nat.testbit (j * 2 ^ n + i) s = + if s + i mod 2^m = j mod 2^m -> i mod 2^n = j mod 2^n. +Proof. + intros Hnm Heq. + replace m with (n + (m - n)) in * by lia. + generalize dependent (m - n). + intros k _. + rewrite Nat.pow_add_r, 2!Nat.Div0.mod_mul_r. + intros H. + apply (f_equal (fun k => k mod 2^n)) in H. + revert H. + rewrite 2!(Nat.mul_comm (2^n)). + rewrite 2!Nat.Div0.mod_add, 2!Nat.Div0.mod_mod. + easy. +Qed. + +Lemma bits_inj_upto i j n : + (forall s, s < n -> Nat.testbit i s = Nat.testbit j s) <-> + i mod 2^n = j mod 2^n. +Proof. + split. + - intros Heq. + induction n; + [now rewrite 2!Nat.mod_1_r|]. + rewrite 2!mod_2_pow_S. + f_equal; [|apply IHn; intros k Hk; apply Heq; lia]. + rewrite Heq by lia. + easy. + - intros Heq s Hs. + rewrite 2!Nat.testbit_eqb. + rewrite (Nat.div_mod i (2^(S s)) ltac:(apply Nat.pow_nonzero; lia)). + rewrite (Nat.div_mod j (2^(S s)) ltac:(apply Nat.pow_nonzero; lia)). + rewrite (mod_pow2_eq_closed_down i j (S s) n ltac:(lia) Heq). + rewrite 2!(Nat.mul_comm (2^ S s)), 2!(Nat.add_comm (_*_)). + rewrite Nat.pow_succ_r by lia. + rewrite 2!Nat.mul_assoc. + rewrite 2!Nat.div_add by (apply Nat.pow_nonzero; lia). + rewrite 2!Nat.Div0.mod_add. + easy. +Qed. + +Lemma lt_pow2_S_log2 i : i < 2 ^ S (Nat.log2 i). +Proof. + destruct i; [cbn; lia|]. + apply Nat.log2_spec; lia. +Qed. + +Lemma bits_inj_upto_small i j n (Hi : i < 2^n) (Hj : j < 2^n) : + (forall s, s < n -> Nat.testbit i s = Nat.testbit j s) <-> + i = j. +Proof. + split; [|intros ->; easy]. + intros H; apply bits_inj_upto in H. + assert (H2n : 2^n <> 0) by (apply Nat.pow_nonzero; lia). + rewrite (Nat.div_mod i (2^n) H2n), (Nat.div_mod j (2^n) H2n). + rewrite 2!Nat.div_small, Nat.mul_0_r by lia. + easy. +Qed. + +Lemma bits_inj i j : + (forall s, Nat.testbit i s = Nat.testbit j s) <-> i = j. +Proof. + split; [|intros ->; easy]. + set (ub := 2^ max (S (Nat.log2 i)) (S (Nat.log2 j))). + assert (Hi : i < ub) by + (enough (i < 2 ^ (S (Nat.log2 i))) by + (pose proof (Nat.pow_le_mono_r 2 (S (Nat.log2 i)) _ + ltac:(easy) (Nat.le_max_l _ (S (Nat.log2 j)))); lia); + apply lt_pow2_S_log2). + assert (Hj : j < ub) by + (enough (j < 2 ^ (S (Nat.log2 j))) by + (pose proof (Nat.pow_le_mono_r 2 (S (Nat.log2 j)) _ + ltac:(easy) (Nat.le_max_r (S (Nat.log2 i)) _)); lia); + apply lt_pow2_S_log2). + intros s. + apply (bits_inj_upto_small i j _ Hi Hj). + intros; easy. +Qed. + +Lemma testbit_make_gap i m k s : + Nat.testbit (i mod 2^m + (i/2^m) * 2^k * (2^m)) s = + if s Nat.testbit m k = false) -> + n + m = Nat.lxor n m. +Proof. + intros Hnm. + apply bits_inj. + intros s. + rewrite lxor_spec. + revert n m Hnm. + induction s; + intros n m Hnm; + [apply Nat.odd_add|]. + simpl. + rewrite !div2_div. + rewrite div_add'. + rewrite <- !bit0_mod. + rewrite (div_small (_ + _)), Nat.add_0_r by + (generalize (Hnm 0); + destruct (testbit n 0), (testbit m 0); + simpl; lia). + apply IHs. + intros k. + rewrite 2!div2_bits; auto. +Qed. + +Lemma testbit_add_disjoint_pow2_l k n : + Nat.testbit n k = false -> + forall i, + Nat.testbit (2^k + n) i = (i =? k) || testbit n i. +Proof. + intros Hn i. + rewrite sum_eq_lxor_of_bits_disj_l, lxor_spec, pow2_bits_eqb, eqb_sym. + - bdestruct (i =? k). + + subst. + now rewrite Hn. + + now destruct (testbit n i). + - intros s. + rewrite pow2_bits_eqb. + bdestructΩ'. +Qed. + +Lemma testbit_sum_pows_2_ne k l : k <> l -> forall i, + Nat.testbit (2 ^ k + 2 ^ l) i = (i =? k) || (i =? l). +Proof. + intros Hkl i. + rewrite testbit_add_disjoint_pow2_l; + rewrite pow2_bits_eqb; bdestructΩ'. +Qed. + +Lemma testbit_add_disjoint m n : + (forall k, Nat.testbit n k = true -> Nat.testbit m k = false) -> + forall i, + Nat.testbit (n + m) i = testbit n i || testbit m i. +Proof. + intros Hn i. + rewrite sum_eq_lxor_of_bits_disj_l, lxor_spec by easy. + generalize (Hn i). + destruct (testbit n i), (testbit m i); lia + auto. +Qed. + +Lemma testbit_b2n b k : + testbit (b2n b) k = b && (k =? 0). +Proof. + destruct b, k; easy + apply Nat.bits_0. +Qed. + +Lemma testbit_decomp n k : + n = (n / 2 ^ (S k)) * 2 ^ (S k) + + b2n (testbit n k) * 2 ^ k + (n mod 2^k). +Proof. + apply bits_inj. + intros s. + rewrite Nat.pow_succ_r, Nat.mul_assoc, <- Nat.mul_add_distr_r by lia. + rewrite testbit_add_pow2_split by show_moddy_lt. + change 2 with (2^1) at 4. + rewrite testbit_add_pow2_split by (destruct (testbit n k); simpl; lia). + rewrite testbit_b2n. + rewrite <- Nat.pow_succ_r by lia. + rewrite testbit_div_pow2, testbit_mod_pow2. + bdestructΩ'; rewrite ?andb_true_r; f_equal; lia. +Qed. + +End nat_lemmas. + Ltac simplify_mods_of a b := - first [ - rewrite (Nat.mod_small a b) in * by lia - | rewrite (mod_n_to_2n a b) in * by lia - ]. + first [ + rewrite (Nat.mod_small a b) in * by lia + | rewrite (mod_n_to_2n a b) in * by lia]. Ltac solve_simple_mod_eqns := let __fail_if_has_mods a := @@ -77,22 +1565,242 @@ Ltac solve_simple_mod_eqns := | _ => idtac end in - match goal with - | |- context[if _ then _ else _] => fail 1 "Cannot solve equation with if" - | _ => - repeat first [ + match goal with + | |- context[if _ then _ else _] => fail 1 "Cannot solve equation with if" + | _ => + repeat first [ easy - | lia - | match goal with - | |- context[?a mod ?b] => __fail_if_has_mods a; __fail_if_has_mods b; - simplify_mods_of a b - | H: context[?a mod ?b] |- _ => __fail_if_has_mods a; __fail_if_has_mods b; - simplify_mods_of a b - end - | match goal with - | |- context[?a mod ?b] => (* idtac a b; *) bdestruct (a + __fail_if_has_mods a; __fail_if_has_mods b; + simplify_mods_of a b + | H: context[?a mod ?b] |- _ => + __fail_if_has_mods a; __fail_if_has_mods b; + simplify_mods_of a b + end + | match goal with + | |- context[?a mod ?b] => (* idtac a b; *) + bdestruct (a + (if b then u else v) = u. +Proof. + bdestructΩ'. +Qed. + +Lemma if_false {A} b (u v : A) : + b = false -> + (if b then u else v) = v. +Proof. + bdestructΩ'. +Qed. + +Lemma if_dist' {A B} (f : A -> B) (b : bool) (x y : A) : + f (if b then x else y) = if b then f x else f y. +Proof. + now destruct b. +Qed. + +Lemma orb_if {A} b c (v v' : A) : + (if (b || c) then v else v') = + if b then v else if c then v else v'. +Proof. + bdestructΩ'. +Qed. + +Lemma f_equal_if {A} (b c : bool) (u v x y : A) : + b = c -> u = v -> x = y -> + (if b then u else x) = (if c then v else y). +Proof. + intros; subst; easy. +Qed. + +Lemma f_equal_if_precedent {A} b c (v1 v2 u1 u2 : A) : + b = c -> + (b = true -> c = true -> v1 = v2) -> + (b = false -> c = false -> u1 = u2) -> + (if b then v1 else u1) = (if c then v2 else u2). +Proof. + intros ->. + destruct c; auto. +Qed. + +Lemma f_equal_if_precedent_same {A} b (v1 v2 u1 u2 : A) : + (b = true -> v1 = v2) -> + (b = false -> u1 = u2) -> + (if b then v1 else u1) = (if b then v2 else u2). +Proof. + intros. + apply f_equal_if_precedent; auto. +Qed. + +Lemma and_same (P : Prop) : P /\ P <-> P. +Proof. split; try intros []; auto. Qed. + +Local Open Scope nat_scope. + +Lemma and_andb {P P'} {b b'} : + reflect P b -> reflect P' b' -> + reflect (P /\ P') (b && b'). +Proof. + intros H H'; apply reflect_iff in H, H'. + apply iff_reflect. + rewrite andb_true_iff. + now rewrite H, H'. +Qed. + +Lemma forall_iff {A} (f g : A -> Prop) : + (forall a, (f a <-> g a)) -> + ((forall a, f a) <-> (forall a, g a)). +Proof. + intros ?; split; intros; apply H; auto. +Qed. + +Lemma impl_iff (P Q Q' : Prop) : + ((P -> Q) <-> (P -> Q')) <-> + (P -> (Q <-> Q')). +Proof. + split; + intros ?; split; intros; apply H; auto. +Qed. + +Import Setoid. + +Lemma Forall_forallb {A} (f : A -> bool) (P : A -> Prop) + (Hf : forall a, P a <-> f a = true) : + forall l, Forall P l <-> forallb f l = true. +Proof. + intros l. + induction l; [repeat constructor|]. + simpl. + rewrite andb_true_iff. + rewrite Forall_cons_iff. + apply Morphisms_Prop.and_iff_morphism; easy. +Qed. + +Lemma eq_eqb_iff (b c : bool) : + b = c <-> eqb b c = true. +Proof. + destruct b, c ; easy. +Qed. + +Lemma eqb_true_l b : eqb true b = b. +Proof. now destruct b. Qed. + +Lemma eqb_true_r b : eqb b true = b. +Proof. now destruct b. Qed. + +Lemma Forall_seq {start len : nat} f : + Forall f (seq start len) <-> forall k, k < len -> f (start + k). +Proof. + revert start; + induction len; intros start; + [split; constructor + lia|]. + simpl. + rewrite Forall_cons_iff. + split. + - intros [Hfk H]. + rewrite IHlen in H. + intros k Hk. + destruct k. + + rewrite Nat.add_0_r; easy. + + specialize (H k). + rewrite Nat.add_succ_r. + apply H. + lia. + - intros H. + rewrite IHlen; split. + + specialize (H 0). + rewrite Nat.add_0_r in H. + apply H; lia. + + intros k Hk; specialize (H (S k)). + rewrite Nat.add_succ_r in H. + apply H. + lia. +Qed. + +Lemma Forall_seq0 {len : nat} f : + Forall f (seq 0 len) <-> forall k, k < len -> f k. +Proof. + apply (@Forall_seq 0 len f). +Qed. + +Lemma forallb_seq (f : nat -> bool) n m : + forallb f (seq m n) = true <-> + (forall s, s < n -> f (s + m) = true). +Proof. + revert m; + induction n; intros m; [easy|]. + simpl. + rewrite andb_true_iff, IHn. + split. + - intros [Hm Hlt]. + intros s. + destruct s; [easy|]. + setoid_rewrite Nat.add_succ_r in Hlt. + intros. + apply Hlt; lia. + - intros Hlt; split. + + apply (Hlt 0 ltac:(lia)). + + intros s Hs. + rewrite Nat.add_succ_r. + apply (Hlt (S s)). + lia. +Qed. + +Lemma forallb_seq0 (f : nat -> bool) n : + forallb f (seq 0 n) = true <-> + (forall s, s < n -> f s = true). +Proof. + rewrite forallb_seq. + now setoid_rewrite Nat.add_0_r. +Qed. + +Lemma forall_lt_sum_split n m (P : nat -> Prop) : + (forall k, k < n + m -> P k) <-> + (forall k, k < n -> P k) /\ (forall k, k < m -> P (n + k)). +Proof. + split; [intros H; split; intros; apply H; lia|]. + intros [Hlow Hhigh]. + intros. + bdestruct (k P. +Proof. tauto. Qed. + +Lemma and_True_r P : P /\ True <-> P. +Proof. tauto. Qed. + +Lemma and_iff_distr_l (P Q R : Prop) : + (P -> (Q <-> R)) <-> (P /\ Q <-> P /\ R). +Proof. tauto. Qed. + +Lemma and_iff_distr_r (P Q R : Prop) : + (P -> (Q <-> R)) <-> (Q /\ P <-> R /\ P). +Proof. rewrite and_iff_distr_l. now rewrite 2!(and_comm P). Qed. + +End Assorted_lemmas. \ No newline at end of file diff --git a/Pad.v b/Pad.v index 88c7ba9..bb02e65 100644 --- a/Pad.v +++ b/Pad.v @@ -351,7 +351,7 @@ Proof. trivial. all : simpl; NoDupity; auto with wf_db. all : rewrite perm_swap; apply perm_skip; apply perm_swap. - Qed. +Qed. (** Unitarity *) @@ -363,10 +363,12 @@ Proof. intros n u start dim B [WF U]. split. apply WF_pad; auto. unfold pad. - gridify. + Modulus.bdestructΩ'. Msimpl. rewrite U. - reflexivity. + rewrite 2!id_kron. + f_equal. + unify_pows_two. Qed. Lemma pad_u_unitary : forall dim n u, @@ -419,6 +421,15 @@ Proof. apply pad_ctrl_unitary; auto; apply σx_unitary. Qed. +Lemma pad_u_mmult : forall dim b A B, WF_Matrix A -> WF_Matrix B -> + pad_u dim b (A × B) = pad_u dim b A × pad_u dim b B. +Proof. + intros. + unfold pad_u, pad. + bdestruct_all; now Msimpl. +Qed. + + (** Lemmas about commutation *) Lemma pad_A_B_commutes : forall dim m n A B, diff --git a/PermutationAutomation.v b/PermutationAutomation.v index 8438d09..20e05b2 100644 --- a/PermutationAutomation.v +++ b/PermutationAutomation.v @@ -1,1111 +1,520 @@ Require Import Bits. -Require Import VectorStates. Require Import Modulus. -Require Import Permutations. +Require Export PermutationsBase. Local Open Scope perm_scope. Local Open Scope nat_scope. -(* Stack and swap perms definitions *) -Definition stack_perms (n0 n1 : nat) (f g : nat -> nat) : nat -> nat := - fun n => - if (n nat := - fun n => if 2 <=? n then n else match n with - | 0 => 1%nat - | 1 => 0%nat - | other => other - end. - -Definition swap_perm a b n := - fun k => if n <=? k then k else - if k =? a then b else - if k =? b then a else k. - -Definition rotr n m : nat -> nat := - fun k => if n <=? k then k else (k + m) mod n. - -Definition rotl n m : nat -> nat := - fun k => if n <=? k then k else (k + (n - (m mod n))) mod n. - -Ltac bdestruct_one := - let fail_if_iffy H := - match H with - | context[if _ then _ else _] => fail 1 - | _ => idtac - end - in - match goal with - | |- context [ ?a fail_if_iffy a; fail_if_iffy b; bdestruct (a fail_if_iffy a; fail_if_iffy b; bdestruct (a <=? b) - | |- context [ ?a =? ?b ] => fail_if_iffy a; fail_if_iffy b; bdestruct (a =? b) - | |- context[if ?b then _ else _] - => fail_if_iffy b; destruct b eqn:? - end. - -Ltac bdestructΩ' := - let tryeasylia := try easy; try lia in - repeat (bdestruct_one; subst; tryeasylia); - tryeasylia. - -Tactic Notation "cleanup_perm_inv" := - autorewrite with perm_inv_db. - -Tactic Notation "cleanup_perm" := - autorewrite with perm_inv_db perm_cleanup_db. +Create HintDb perm_unfold_db. +Create HintDb perm_cleanup_db. +Create HintDb proper_side_conditions_db. -Tactic Notation "cleanup_perm_of_zx" := - autounfold with zxperm_db; - autorewrite with perm_of_zx_cleanup_db perm_inv_db perm_cleanup_db. +Ltac auto_perm_to n := + auto n with perm_db perm_bounded_db WF_Perm_db perm_inv_db. -Lemma compose_id_of_compose_idn {f g : nat -> nat} - (H : (f ∘ g)%prg = (fun n => n)) {k : nat} : f (g k) = k. -Proof. - apply (f_equal_inv k) in H. - easy. -Qed. +Ltac auto_perm := + auto 6 with perm_db perm_bounded_db WF_Perm_db perm_inv_db. -Ltac perm_by_inverse finv := - let tryeasylia := try easy; try lia in - exists finv; - intros k Hk; repeat split; - only 3,4 : (try apply compose_id_of_compose_idn; cleanup_perm; tryeasylia) - || cleanup_perm; tryeasylia; - only 1,2 : auto with perm_bounded_db; tryeasylia. +Tactic Notation "auto_perm" int_or_var(n) := + auto_perm_to n. -(* Section on swap_perm, swaps two elements *) -Lemma swap_perm_same a n : - swap_perm a a n = idn. -Proof. - unfold swap_perm. - apply functional_extensionality; intros k. - bdestructΩ'. -Qed. +Tactic Notation "auto_perm" := + auto_perm 6. -#[export] Hint Rewrite swap_perm_same : perm_cleanup_db. +#[export] Hint Resolve + permutation_is_bounded + permutation_is_injective + permutation_is_surjective : perm_db. -Lemma swap_perm_comm a b n : - swap_perm a b n = swap_perm b a n. -Proof. - apply functional_extensionality; intros k. - unfold swap_perm. - bdestructΩ'. -Qed. +#[export] Hint Extern 0 (perm_inj ?n ?f) => + apply (permutation_is_injective n f) : perm_db. -Lemma swap_WF_perm a b n : forall k, n <= k -> swap_perm a b n k = k. -Proof. - intros. - unfold swap_perm. - bdestructΩ'. -Qed. +#[export] Hint Resolve + permutation_compose : perm_db. -#[export] Hint Resolve swap_WF_perm : WF_perm_db. +#[export] Hint Resolve compose_WF_Perm : WF_Perm_db. +#[export] Hint Rewrite @compose_idn_r @compose_idn_l : perm_cleanup_db. -Lemma swap_perm_bounded a b n : a < n -> b < n -> - forall k, k < n -> swap_perm a b n k < n. -Proof. - intros Ha Hb k Hk. - unfold swap_perm. - bdestructΩ'. -Qed. +#[export] Hint Extern 100 (_ < _) => + show_moddy_lt : perm_bounded_db. -#[export] Hint Resolve swap_perm_bounded : perm_bounded_db. +#[export] Hint Extern 0 (funbool_to_nat ?n ?f < ?b) => + apply (Nat.lt_le_trans (Bits.funbool_to_nat n f) (2^n) b); + [apply (Bits.funbool_to_nat_bound n f) | show_pow2_le] : show_moddy_lt_db. -Lemma swap_perm_inv a b n : a < n -> b < n -> - ((swap_perm a b n) ∘ (swap_perm a b n))%prg = idn. -Proof. - intros Ha Hb. - unfold compose. - apply functional_extensionality; intros k. - unfold swap_perm. - bdestructΩ'. -Qed. - -#[export] Hint Rewrite swap_perm_inv : perm_inv_db. - -Lemma swap_perm_2_perm a b n : a < n -> b < n -> - permutation n (swap_perm a b n). -Proof. - intros Ha Hb. - perm_by_inverse (swap_perm a b n). -Qed. - -#[export] Hint Resolve swap_perm_2_perm : perm_db. - -Lemma swap_perm_S_permutation a n (Ha : S a < n) : - permutation n (swap_perm a (S a) n). -Proof. - apply swap_perm_2_perm; lia. -Qed. - -#[export] Hint Resolve swap_perm_S_permutation : perm_db. - -Lemma compose_swap_perm a b c n : a < n -> b < n -> c < n -> - b <> c -> a <> c -> - (swap_perm a b n ∘ swap_perm b c n ∘ swap_perm a b n)%prg = swap_perm a c n. -Proof. - intros Ha Hb Hc Hbc Hac. - apply functional_extensionality; intros k. - unfold compose, swap_perm. - bdestructΩ'. -Qed. - -#[export] Hint Rewrite compose_swap_perm : perm_cleanup_db. - -(* Section for swap_2_perm *) -Lemma swap_2_perm_inv : - (swap_2_perm ∘ swap_2_perm)%prg = idn. -Proof. - apply functional_extensionality; intros k. - repeat first [easy | destruct k]. -Qed. - -#[export] Hint Rewrite swap_2_perm_inv : perm_inv_db. - -Lemma swap_2_perm_bounded k : - k < 2 -> swap_2_perm k < 2. -Proof. - intros Hk. - repeat first [easy | destruct k | cbn; lia]. -Qed. - -#[export] Hint Resolve swap_2_perm_bounded : perm_bounded_db. - -Lemma swap_2_WF_perm k : 1 < k -> swap_2_perm k = k. -Proof. - intros. - repeat first [easy | lia | destruct k]. -Qed. - -Global Hint Resolve swap_2_WF_perm : WF_perm_db. +Ltac show_permutation := + repeat first [ + split + | simpl; solve [auto with perm_db] + | subst; solve [auto with perm_db] + | solve [eauto using permutation_compose with perm_db] + | easy + | lia + ]. -Lemma swap_2_perm_permutation : permutation 2 swap_2_perm. -Proof. - perm_by_inverse swap_2_perm. -Qed. +Ltac cleanup_perm_inv := + auto with perm_inv_db perm_db perm_bounded_db WF_Perm_db; + autorewrite with perm_inv_db; + auto with perm_inv_db perm_db perm_bounded_db WF_Perm_db. -Global Hint Resolve swap_2_perm_permutation : perm_db. +Ltac cleanup_perm := + auto with perm_inv_db perm_cleanup_db perm_db perm_bounded_db WF_Perm_db; + autorewrite with perm_inv_db perm_cleanup_db; + auto with perm_inv_db perm_cleanup_db perm_db perm_bounded_db WF_Perm_db. -(* Section for stack_perms *) + Ltac solve_modular_permutation_equalities := - first [cleanup_perm_of_zx | cleanup_perm_inv | cleanup_perm]; - unfold Basics.compose, rotr, rotl, stack_perms, swap_perm, - (* TODO: remove *) swap_2_perm; + first [cleanup_perm_inv | cleanup_perm]; + autounfold with perm_unfold_db; apply functional_extensionality; let k := fresh "k" in intros k; bdestructΩ'; solve_simple_mod_eqns. -Lemma stack_perms_WF_idn {n0 n1} {f} - (H : forall k, n0 <= k -> f k = k): - stack_perms n0 n1 f idn = f. -Proof. - solve_modular_permutation_equalities; - rewrite H; lia. -Qed. - -Lemma stack_perms_WF {n0 n1} {f g} k : - n0 + n1 <= k -> stack_perms n0 n1 f g k = k. -Proof. - intros H. - unfold stack_perms. - bdestructΩ'. -Qed. - -Global Hint Resolve stack_perms_WF : WF_perm_db. - -Lemma stack_perms_bounded {n0 n1} {f g} - (Hf : forall k, k < n0 -> f k < n0) (Hg : forall k, k < n1 -> g k < n1) : - forall k, k < n0 + n1 -> stack_perms n0 n1 f g k < n0 + n1. -Proof. - intros k Hk. - unfold stack_perms. - bdestruct (k (f k < n0 /\ finv k < n0 /\ finv (f k) = k /\ f (finv k) = k)) - (Hg: forall k, k < n1 -> (g k < n1 /\ ginv k < n1 /\ ginv (g k) = k /\ g (ginv k) = k)) : - (stack_perms n0 n1 f g ∘ stack_perms n0 n1 finv ginv)%prg = idn. -Proof. - unfold compose. - solve_modular_permutation_equalities. - 1-3: specialize (Hf _ H); lia. - - replace (ginv (k - n0) + n0 - n0) with (ginv (k - n0)) by lia. - assert (Hkn0: k - n0 < n1) by lia. - specialize (Hg _ Hkn0). - lia. - - assert (Hkn0: k - n0 < n1) by lia. - specialize (Hg _ Hkn0). - lia. -Qed. - -Lemma is_inv_iff_inv_is n f finv : - (forall k, k < n -> finv k < n /\ f k < n /\ f (finv k) = k /\ finv (f k) = k)%nat - <-> (forall k, k < n -> f k < n /\ finv k < n /\ finv (f k) = k /\ f (finv k) = k)%nat. -Proof. - split; intros H k Hk; specialize (H k Hk); easy. -Qed. - -#[export] Hint Rewrite is_inv_iff_inv_is : perm_inv_db. - -Lemma stack_perms_linv {n0 n1} {f g} {finv ginv} - (Hf: forall k, k < n0 -> (f k < n0 /\ finv k < n0 /\ finv (f k) = k /\ f (finv k) = k)) - (Hg: forall k, k < n1 -> (g k < n1 /\ ginv k < n1 /\ ginv (g k) = k /\ g (ginv k) = k)) : - (stack_perms n0 n1 finv ginv ∘ stack_perms n0 n1 f g)%prg = idn. -Proof. - rewrite stack_perms_rinv. - 2,3: rewrite is_inv_iff_inv_is. - all: easy. -Qed. - -#[export] Hint Rewrite @stack_perms_rinv @stack_perms_linv : perm_inv_db. - -Lemma stack_perms_permutation {n0 n1 f g} (Hf : permutation n0 f) (Hg: permutation n1 g) : - permutation (n0 + n1) (stack_perms n0 n1 f g). -Proof. - destruct Hf as [f' Hf']. - destruct Hg as [g' Hg']. - perm_by_inverse (stack_perms n0 n1 f' g'). - 1,2: apply stack_perms_bounded; try easy; intros k' Hk'; - try specialize (Hf' _ Hk'); try specialize (Hg' _ Hk'); easy. - 1,2: rewrite is_inv_iff_inv_is; easy. -Qed. - -Global Hint Resolve stack_perms_permutation : perm_db. - -(* Section on insertion_sort_list *) -Fixpoint insertion_sort_list n f := - match n with - | 0 => [] - | S n' => let k := (perm_inv (S n') f n') in - k :: insertion_sort_list n' (fswap f k n') - end. - -Fixpoint swap_list_spec l : bool := - match l with - | [] => true - | k :: ks => (k idn - | k :: ks => let n := length ks in - (swap_perm k n (S n) ∘ (perm_of_swap_list ks))%prg - end. - -Fixpoint invperm_of_swap_list l := - match l with - | [] => idn - | k :: ks => let n := length ks in - ((invperm_of_swap_list ks) ∘ swap_perm k n (S n))%prg - end. - -Lemma fswap_eq_compose_swap_perm {A} (f : nat -> A) n m o : n < o -> m < o -> - fswap f n m = (f ∘ swap_perm n m o)%prg. -Proof. - intros Hn Hm. - apply functional_extensionality; intros k. - unfold compose, fswap, swap_perm. - bdestruct_all; easy. -Qed. - -Lemma fswap_perm_inv_n_permutation f n : permutation (S n) f -> - permutation n (fswap f (perm_inv (S n) f n) n). -Proof. - intros Hperm. - apply fswap_at_boundary_permutation. - - apply Hperm. - - apply perm_inv_bounded_S. - - apply perm_inv_is_rinv_of_permutation; auto. -Qed. - -Lemma perm_of_swap_list_WF l : swap_list_spec l = true -> - WF_Perm (length l) (perm_of_swap_list l). -Proof. - induction l. - - easy. - - simpl. - rewrite andb_true_iff. - intros [Ha Hl]. - intros k Hk. - unfold compose. - rewrite IHl; [|easy|lia]. - rewrite swap_WF_perm; easy. -Qed. - -Lemma invperm_of_swap_list_WF l : swap_list_spec l = true -> - WF_Perm (length l) (invperm_of_swap_list l). +Lemma compose_id_of_compose_idn {f g : nat -> nat} + (H : (f ∘ g)%prg = (fun n => n)) {k : nat} : f (g k) = k. Proof. - induction l. - - easy. - - simpl. - rewrite andb_true_iff. - intros [Ha Hl]. - intros k Hk. - unfold compose. - rewrite swap_WF_perm; [|easy]. - rewrite IHl; [easy|easy|lia]. + apply (f_equal_inv k) in H. + easy. Qed. -#[export] Hint Resolve perm_of_swap_list_WF invperm_of_swap_list_WF : WF_perm_db. - -Lemma perm_of_swap_list_bounded l : swap_list_spec l = true -> - perm_bounded (length l) (perm_of_swap_list l). -Proof. - induction l; [easy|]. - simpl. - rewrite andb_true_iff. - intros [Ha Hl]. - intros k Hk. - unfold compose. - rewrite Nat.ltb_lt in Ha. - apply swap_perm_bounded; try lia. - bdestruct (k =? length l). - - subst; rewrite perm_of_swap_list_WF; try easy; lia. - - transitivity (length l); [|lia]. - apply IHl; [easy | lia]. -Qed. +Ltac perm_by_inverse finv := + let tryeasylia := try easy; try lia in + exists finv; + intros k Hk; repeat split; + only 3,4 : + (solve [apply compose_id_of_compose_idn; cleanup_perm; tryeasylia]) + || cleanup_perm; tryeasylia; + only 1,2 : auto with perm_bounded_db; tryeasylia. -Lemma invperm_of_swap_list_bounded l : swap_list_spec l = true -> - perm_bounded (length l) (invperm_of_swap_list l). -Proof. - induction l; [easy|]. - simpl. - rewrite andb_true_iff. - intros [Ha Hl]. - rewrite Nat.ltb_lt in Ha. - intros k Hk. - unfold compose. - bdestruct (swap_perm a (length l) (S (length l)) k =? length l). - - rewrite H, invperm_of_swap_list_WF; [lia|easy|easy]. - - transitivity (length l); [|lia]. - apply IHl; [easy|]. - pose proof (swap_perm_bounded a (length l) (S (length l)) Ha (ltac:(lia)) k Hk). - lia. -Qed. +Ltac permutation_eq_by_WF_inv_inj f n := + let tryeasylia := (try easy); (try lia) in + apply (WF_permutation_inverse_injective f n); [ + tryeasylia; auto with perm_db | + tryeasylia; auto with WF_Perm_db | + try solve [cleanup_perm; auto] | + try solve [cleanup_perm; auto]]; + tryeasylia. -#[export] Hint Resolve perm_of_swap_list_bounded invperm_of_swap_list_bounded : perm_bounded_db. +Ltac perm_eq_by_inv_inj f n := + let tryeasylia := (try easy); (try lia) in + apply (perm_inv_perm_eq_injective f n); [ + tryeasylia; auto with perm_db | + try solve [cleanup_perm; auto] | + try solve [cleanup_perm; auto]]; + tryeasylia. -Lemma invperm_linv_perm_of_swap_list l : swap_list_spec l = true -> - (invperm_of_swap_list l ∘ perm_of_swap_list l)%prg = idn. -Proof. - induction l. - - easy. - - simpl. - rewrite andb_true_iff. - intros [Ha Hl]. - rewrite Combinators.compose_assoc, - <- (Combinators.compose_assoc _ _ _ _ (perm_of_swap_list _)). - rewrite swap_perm_inv, compose_idn_l. - + apply (IHl Hl). - + bdestructΩ (a - (perm_of_swap_list l ∘ invperm_of_swap_list l)%prg = idn. -Proof. - induction l. - - easy. - - simpl. - rewrite andb_true_iff. - intros [Ha Hl]. - rewrite <- Combinators.compose_assoc, - (Combinators.compose_assoc _ _ _ _ (invperm_of_swap_list _)). - rewrite (IHl Hl). - rewrite compose_idn_r. - rewrite swap_perm_inv; [easy| |lia]. - bdestructΩ (a True. -Local Opaque perm_inv. -Lemma insertion_sort_list_is_swap_list n f : - swap_list_spec (insertion_sort_list n f) = true. -Proof. - revert f; - induction n; - intros f. - - easy. - - simpl. - rewrite length_insertion_sort_list, IHn. - pose proof (perm_inv_bounded_S n f n). - bdestructΩ (perm_inv (S n) f n - forall k, k < n -> - (f ∘ perm_of_swap_list (insertion_sort_list n f))%prg k = k. -Proof. - revert f; - induction n; - intros f. - - intros; exfalso; easy. - - intros Hperm k Hk. - simpl. - rewrite length_insertion_sort_list. - bdestruct (k =? n). - + unfold compose. - rewrite perm_of_swap_list_WF; [ | - apply insertion_sort_list_is_swap_list | - rewrite length_insertion_sort_list; lia - ]. - unfold swap_perm. - bdestructΩ (S n <=? k). - bdestructΩ (k =? n). - subst. - bdestruct (n =? perm_inv (S n) f n). - 1: rewrite H at 1. - all: rewrite perm_inv_is_rinv_of_permutation; [easy|easy|lia]. - + rewrite <- Combinators.compose_assoc. - rewrite <- fswap_eq_compose_swap_perm; [|apply perm_inv_bounded_S|lia]. - rewrite IHn; [easy| |lia]. - apply fswap_perm_inv_n_permutation, Hperm. -Qed. -Local Transparent perm_inv. +#[export] Hint Unfold true_rel : typeclass_instances. -Lemma perm_of_insertion_sort_list_WF n f : - WF_Perm n (perm_of_swap_list (insertion_sort_list n f)). +#[export] Instance true_rel_superrel {A} (R : relation A) : + subrelation R true_rel | 10. Proof. - intros k. - rewrite <- (length_insertion_sort_list n f) at 1. - revert k. - apply perm_of_swap_list_WF. - apply insertion_sort_list_is_swap_list. + intros x y H. + constructor. Qed. -Lemma invperm_of_insertion_sort_list_WF n f : - WF_Perm n (invperm_of_swap_list (insertion_sort_list n f)). -Proof. - intros k. - rewrite <- (length_insertion_sort_list n f) at 1. - revert k. - apply invperm_of_swap_list_WF. - apply insertion_sort_list_is_swap_list. -Qed. +Definition on_predicate_relation_l {A} (P : A -> Prop) (R : relation A) + : relation A := + fun (x y : A) => P x /\ R x y. -#[export] Hint Resolve perm_of_insertion_sort_list_WF invperm_of_swap_list_WF : WF_perm_db. +Definition on_predicate_relation_r {A} (P : A -> Prop) (R : relation A) + : relation A := + fun (x y : A) => P y /\ R x y. -Lemma perm_of_insertion_sort_list_perm_eq_perm_inv n f : permutation n f -> - perm_eq n (perm_of_swap_list (insertion_sort_list n f)) (perm_inv n f). +Lemma proper_proxy_on_predicate_l {A} (P : A -> Prop) (R : relation A) + (x : A) : + P x -> + Morphisms.ProperProxy R x -> + Morphisms.ProperProxy (on_predicate_relation_l P R) x. Proof. - intros Hperm. - apply (perm_bounded_rinv_injective_of_injective n f). - - apply permutation_is_injective, Hperm. - - pose proof (perm_of_swap_list_bounded (insertion_sort_list n f) - (insertion_sort_list_is_swap_list n f)) as H. - rewrite (length_insertion_sort_list n f) in H. - exact H. - - auto with perm_bounded_db. - - apply perm_of_insertion_sort_list_is_rinv, Hperm. - - apply perm_inv_is_rinv_of_permutation, Hperm. + easy. Qed. -Lemma perm_of_insertion_sort_list_eq_make_WF_perm_inv n f : permutation n f -> - (perm_of_swap_list (insertion_sort_list n f)) = fun k => if n <=?k then k else perm_inv n f k. +Lemma proper_proxy_on_predicate_r {A} (P : A -> Prop) (R : relation A) + (x : A) : + P x -> + Morphisms.ProperProxy R x -> + Morphisms.ProperProxy (on_predicate_relation_r P R) x. Proof. - intros Hperm. - apply functional_extensionality; intros k. - bdestruct (n <=? k). - - rewrite perm_of_insertion_sort_list_WF; easy. - - rewrite perm_of_insertion_sort_list_perm_eq_perm_inv; easy. + easy. Qed. -Lemma perm_eq_linv_injective n f finv finv' : permutation n f -> - is_perm_linv n f finv -> is_perm_linv n f finv' -> - perm_eq n finv finv'. +Lemma proper_proxy_flip_on_predicate_l {A} (P : A -> Prop) (R : relation A) + (x : A) : + P x -> + Morphisms.ProperProxy R x -> + Morphisms.ProperProxy (flip (on_predicate_relation_l P R)) x. Proof. - intros Hperm Hfinv Hfinv' k Hk. - destruct (permutation_is_surjective n f Hperm k Hk) as [k' [Hk' Hfk']]. - unfold compose in *. - specialize (Hfinv k' Hk'). - specialize (Hfinv' k' Hk'). - rewrite Hfk' in *. - rewrite Hfinv, Hfinv'. easy. Qed. -Lemma perm_inv_eq_inv n f finv : - (forall x : nat, x < n -> f x < n /\ finv x < n /\ finv (f x) = x /\ f (finv x) = x) - -> perm_eq n (perm_inv n f) finv. +Lemma proper_proxy_flip_on_predicate_r {A} (P : A -> Prop) (R : relation A) + (x : A) : + P x -> + Morphisms.ProperProxy R x -> + Morphisms.ProperProxy (flip (on_predicate_relation_r P R)) x. Proof. - intros Hfinv. - assert (Hperm: permutation n f) by (exists finv; easy). - apply (perm_eq_linv_injective n f); [easy| | ]; - unfold compose; intros k Hk. - - rewrite perm_inv_is_linv_of_permutation; easy. - - apply Hfinv, Hk. + easy. Qed. -Lemma perm_inv_is_inv n f : permutation n f -> - forall k : nat, k < n -> perm_inv n f k < n /\ f k < n - /\ f (perm_inv n f k) = k /\ perm_inv n f (f k) = k. -Proof. - intros Hperm k Hk. - repeat split. - - apply perm_inv_bounded, Hk. - - destruct Hperm as [? H]; apply H, Hk. - - rewrite perm_inv_is_rinv_of_permutation; easy. - - rewrite perm_inv_is_linv_of_permutation; easy. -Qed. +#[export] Hint Extern 0 + (Morphisms.ProperProxy (on_predicate_relation_l ?P ?R) ?x) => + apply (proper_proxy_on_predicate_l P R x + ltac:(auto with proper_side_conditions_db)) : typeclass_instances. -Lemma perm_inv_perm_inv n f : permutation n f -> - perm_eq n (perm_inv n (perm_inv n f)) f. -Proof. - intros Hperm k Hk. - unfold compose. - rewrite (perm_inv_eq_inv n (perm_inv n f) f); try easy. - apply perm_inv_is_inv, Hperm. -Qed. +#[export] Hint Extern 0 + (Morphisms.ProperProxy (on_predicate_relation_r ?P ?R) ?x) => + apply (proper_proxy_on_predicate_r P R x + ltac:(auto with proper_side_conditions_db)) : typeclass_instances. -Lemma perm_inv_eq_of_perm_eq' n m f g : perm_eq n f g -> m <= n -> - perm_eq n (perm_inv m f) (perm_inv m g). -Proof. - intros Heq Hm. - induction m; [trivial|]. - intros k Hk. - simpl. - rewrite Heq by lia. - rewrite IHm by lia. - easy. -Qed. +#[export] Hint Extern 0 + (Morphisms.ProperProxy (flip (on_predicate_relation_l ?P ?R)) ?x) => + apply (proper_proxy_flip_on_predicate_l P R x + ltac:(auto with proper_side_conditions_db)) : typeclass_instances. -Lemma perm_inv_eq_of_perm_eq n f g : perm_eq n f g -> - perm_eq n (perm_inv n f) (perm_inv n g). -Proof. - intros Heq. - apply perm_inv_eq_of_perm_eq'; easy. -Qed. +#[export] Hint Extern 0 + (Morphisms.ProperProxy (flip (on_predicate_relation_r ?P ?R)) ?x) => + apply (proper_proxy_flip_on_predicate_r P R x + ltac:(auto with proper_side_conditions_db)) : typeclass_instances. -Lemma perm_inv_of_insertion_sort_list_eq n f : permutation n f -> - perm_eq n f (perm_inv n (perm_of_swap_list (insertion_sort_list n f))). -Proof. - intros Hperm k Hk. - rewrite (perm_of_insertion_sort_list_eq_make_WF_perm_inv n f) by easy. - rewrite (perm_inv_eq_of_perm_eq n _ (perm_inv n f)); [ - | intros; bdestructΩ' | easy]. - rewrite perm_inv_perm_inv; easy. -Qed. +#[export] Hint Extern 10 => cbn beta : proper_side_conditions_db. -Lemma perm_of_insertion_sort_list_of_perm_inv_eq n f : permutation n f -> - perm_eq n f (perm_of_swap_list (insertion_sort_list n (perm_inv n f))). -Proof. - intros Hperm. - rewrite perm_of_insertion_sort_list_eq_make_WF_perm_inv by (auto with perm_db). - intros; bdestructΩ'. - rewrite perm_inv_perm_inv; easy. -Qed. +Add Parametric Relation n : (nat -> nat) (perm_eq n) + reflexivity proved by (perm_eq_refl n) + symmetry proved by (@perm_eq_sym n) + transitivity proved by (@perm_eq_trans n) + as perm_eq_Setoid. -Lemma insertion_sort_list_S n f : - insertion_sort_list (S n) f = - (perm_inv (S n) f n) :: (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)). -Proof. easy. Qed. - -Lemma perm_of_swap_list_cons a l : - perm_of_swap_list (a :: l) = (swap_perm a (length l) (S (length l)) ∘ perm_of_swap_list l)%prg. -Proof. easy. Qed. - -Lemma invperm_of_swap_list_cons a l : - invperm_of_swap_list (a :: l) = (invperm_of_swap_list l ∘ swap_perm a (length l) (S (length l)))%prg. -Proof. easy. Qed. - -Lemma perm_of_insertion_sort_list_S n f : - perm_of_swap_list (insertion_sort_list (S n) f) = - (swap_perm (perm_inv (S n) f n) n (S n) ∘ - perm_of_swap_list (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)))%prg. -Proof. - rewrite insertion_sort_list_S, perm_of_swap_list_cons. - rewrite length_insertion_sort_list. - easy. -Qed. +#[export] Hint Extern 0 (perm_bounded _ _) => + solve [auto with perm_bounded_db perm_db] : proper_side_conditions_db. -Lemma invperm_of_insertion_sort_list_S n f : - invperm_of_swap_list (insertion_sort_list (S n) f) = - (invperm_of_swap_list (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)) - ∘ swap_perm (perm_inv (S n) f n) n (S n))%prg. -Proof. - rewrite insertion_sort_list_S, invperm_of_swap_list_cons. - rewrite length_insertion_sort_list. - easy. -Qed. +#[export] Hint Extern 0 (permutation _ _) => + solve [auto with perm_db] : proper_side_conditions_db. -Lemma perm_of_swap_list_permutation l : swap_list_spec l = true -> - permutation (length l) (perm_of_swap_list l). -Proof. - intros Hsw. - induction l; - [ simpl; exists idn; easy |]. - simpl. - apply permutation_compose. - - apply swap_perm_2_perm; [|lia]. - simpl in Hsw. - bdestruct (a + solve [auto with WF_Perm_db] : proper_side_conditions_db. -Lemma invperm_of_swap_list_permutation l : swap_list_spec l = true -> - permutation (length l) (invperm_of_swap_list l). -Proof. - intros Hsw. - induction l; - [ simpl; exists idn; easy |]. - simpl. - apply permutation_compose. - - eapply permutation_monotonic_of_WF. - 2: apply IHl. - 1: lia. - 2: apply invperm_of_swap_list_WF. - all: simpl in Hsw; - rewrite andb_true_iff in Hsw; easy. - - apply swap_perm_2_perm; [|lia]. - simpl in Hsw. - bdestruct (a - perm_eq n f (invperm_of_swap_list (insertion_sort_list n f)). -Proof. - intros Hperm. - apply (perm_eq_linv_injective n (perm_of_swap_list (insertion_sort_list n f))). - - auto with perm_db. - - intros k Hk. - rewrite perm_of_insertion_sort_list_is_rinv; easy. - - intros k Hk. - rewrite invperm_linv_perm_of_swap_list; [easy|]. - apply insertion_sort_list_is_swap_list. -Qed. - -Lemma permutation_grow_l' n f : permutation (S n) f -> - perm_eq (S n) f (swap_perm (f n) n (S n) ∘ - perm_of_swap_list (insertion_sort_list n (fswap (perm_inv (S n) f) (f n) n)))%prg. +Lemma perm_eq_compose_rewrite_l {n} {f g h : nat -> nat} + (H : perm_eq n (f ∘ g) (h)) : forall (i : nat -> nat), + perm_eq n (i ∘ f ∘ g) (i ∘ h). Proof. - intros Hperm k Hk. - rewrite (perm_of_insertion_sort_list_of_perm_inv_eq _ _ Hperm) at 1 by auto. -Local Opaque perm_inv. - simpl. -Local Transparent perm_inv. - rewrite length_insertion_sort_list, perm_inv_perm_inv by auto. - easy. + intros i k Hk. + unfold compose in *. + now rewrite H. Qed. -Lemma permutation_grow_r' n f : permutation (S n) f -> - perm_eq (S n) f ( - invperm_of_swap_list (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)) - ∘ swap_perm (perm_inv (S n) f n) n (S n))%prg. +Lemma perm_eq_compose_rewrite_l_to_2 {n} {f g h i : nat -> nat} + (H : perm_eq n (f ∘ g) (h ∘ i)) : forall (j : nat -> nat), + perm_eq n (j ∘ f ∘ g) (j ∘ h ∘ i). Proof. - intros Hperm k Hk. - rewrite (invperm_of_insertion_sort_list_eq _ _ Hperm) at 1 by auto. -Local Opaque perm_inv. - simpl. -Local Transparent perm_inv. - rewrite length_insertion_sort_list by auto. - easy. + intros j k Hk. + unfold compose in *. + now rewrite H. Qed. -Lemma permutation_grow_l n f : permutation (S n) f -> - exists g k, k < S n /\ perm_eq (S n) f (swap_perm k n (S n) ∘ g)%prg /\ permutation n g. +Lemma perm_eq_compose_rewrite_l_to_Id {n} {f g : nat -> nat} + (H : perm_eq n (f ∘ g) idn) : forall (h : nat -> nat), + perm_eq n (h ∘ f ∘ g) h. Proof. - intros Hperm. - eexists. - exists (f n). - split; [apply permutation_is_bounded; [easy | lia] | split]. - pose proof (perm_of_insertion_sort_list_of_perm_inv_eq _ _ Hperm) as H. - rewrite perm_of_insertion_sort_list_S in H. - rewrite perm_inv_perm_inv in H by (easy || lia). - exact H. - auto with perm_db. + intros h k Hk. + unfold compose in *. + now rewrite H. Qed. -Lemma permutation_grow_r n f : permutation (S n) f -> - exists g k, k < S n /\ perm_eq (S n) f (g ∘ swap_perm k n (S n))%prg /\ permutation n g. +Lemma perm_eq_compose_rewrite_r {n} {f g h : nat -> nat} + (H : perm_eq n (f ∘ g) h) : forall (i : nat -> nat), + perm_bounded n i -> + perm_eq n (f ∘ (g ∘ i)) (h ∘ i). Proof. - intros Hperm. - eexists. - exists (perm_inv (S n) f n). - split; [apply permutation_is_bounded; [auto with perm_db | lia] | split]. - pose proof (invperm_of_insertion_sort_list_eq _ _ Hperm) as H. - rewrite invperm_of_insertion_sort_list_S in H. - exact H. - auto with perm_db. + intros i Hi k Hk. + unfold compose in *. + now rewrite H by auto. Qed. -(* Section on stack_perms *) -Ltac replace_bool_lia b0 b1 := - first [ - replace b0 with b1 by (bdestruct b0; lia || (destruct b1 eqn:?; lia)) | - replace b0 with b1 by (bdestruct b1; lia || (destruct b0 eqn:?; lia)) | - replace b0 with b1 by (bdestruct b0; bdestruct b1; lia) - ]. - -Lemma stack_perms_left {n0 n1} {f g} {k} : - k < n0 -> stack_perms n0 n1 f g k = f k. +Lemma perm_eq_compose_rewrite_r_to_2 {n} {f g h i : nat -> nat} + (H : perm_eq n (f ∘ g) (h ∘ i)) : forall (j : nat -> nat), + perm_bounded n j -> + perm_eq n (f ∘ (g ∘ j)) (h ∘ (i ∘ j)). Proof. - intros Hk. - unfold stack_perms. - replace_bool_lia (k stack_perms n0 n1 f g k = g (k - n0) + n0. +Lemma perm_eq_compose_rewrite_r_to_Id {n} {f g : nat -> nat} + (H : perm_eq n (f ∘ g) idn) : forall (h : nat -> nat), + perm_bounded n h -> + perm_eq n (f ∘ (g ∘ h)) h. Proof. - intros Hk. - unfold stack_perms. - replace_bool_lia (k stack_perms n0 n1 f g (k + n0) = g k + n0. -Proof. - intros Hk. - rewrite stack_perms_right; [|lia]. - replace (k + n0 - n0) with k by lia. - easy. -Qed. +End PermComposeLemmas. -Lemma stack_perms_add_right {n0 n1} {f g} {k} : - k < n1 -> stack_perms n0 n1 f g (n0 + k) = g k + n0. -Proof. - rewrite Nat.add_comm. - exact stack_perms_right_add. -Qed. -Lemma stack_perms_high {n0 n1} {f g} {k} : - n0 + n1 <= k -> (stack_perms n0 n1 f g) k = k. -Proof. - intros H. - unfold stack_perms. - replace_bool_lia (k if k + constr:(perm_eq_compose_rewrite_l_to_Id lem) + | perm_eq ?n (?F ∘ ?G)%prg (?F' ∘ ?G')%prg => + constr:(perm_eq_compose_rewrite_l_to_2 lem) + | perm_eq ?n (?F ∘ ?G)%prg ?H => + constr:(perm_eq_compose_rewrite_l lem) + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_perm_eq_compose_assoc_rewrite_l (lem x) in + exact r)) + end. -Lemma stack_perms_idn_f n0 n1 f : - stack_perms n0 n1 idn f = - fun k => if (¬ k + constr:(perm_eq_compose_rewrite_l_to_Id (perm_eq_sym lem)) + | perm_eq ?n (?F ∘ ?G)%prg (?F' ∘ ?G')%prg => + constr:(perm_eq_compose_rewrite_l_to_2 (perm_eq_sym lem)) + | perm_eq ?n ?H (?F ∘ ?G)%prg => + constr:(perm_eq_compose_rewrite_l (perm_eq_sym lem)) + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_perm_eq_compose_assoc_rewrite_l' (lem x) in + exact r)) + end. -Lemma stack_perms_idn_idn n0 n1 : - stack_perms n0 n1 idn idn = idn. -Proof. solve_modular_permutation_equalities. Qed. +Ltac rewrite_perm_eq_compose_assoc_l lem := + let lem' := make_perm_eq_compose_assoc_rewrite_l lem in + rewrite lem' || rewrite lem. + +Ltac rewrite_perm_eq_compose_assoc_l' lem := + let lem' := make_perm_eq_compose_assoc_rewrite_l' lem in + rewrite lem' || rewrite <- lem. + +Ltac make_perm_eq_compose_assoc_rewrite_r lem := + lazymatch type of lem with + | perm_eq ?n (?F ∘ ?G)%prg idn => + constr:(perm_eq_compose_rewrite_r_to_Id lem) + | perm_eq ?n (?F ∘ ?G)%prg (?F' ∘ ?G')%prg => + constr:(perm_eq_compose_rewrite_r_to_2 lem) + | perm_eq ?n (?F ∘ ?G)%prg ?H => + constr:(perm_eq_compose_rewrite_r lem) + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_perm_eq_compose_assoc_rewrite_r (lem x) in + exact r)) + end. -#[export] Hint Rewrite stack_perms_idn_idn : perm_cleanup_db. +Ltac make_perm_eq_compose_assoc_rewrite_r' lem := + lazymatch type of lem with + | perm_eq ?n idn (?F ∘ ?G)%prg => + constr:(perm_eq_compose_rewrite_r_to_Id (perm_eq_sym lem)) + | perm_eq ?n (?F ∘ ?G)%prg (?F' ∘ ?G')%prg => + constr:(perm_eq_compose_rewrite_r_to_2 (perm_eq_sym lem)) + | perm_eq ?n ?H (?F ∘ ?G)%prg => + constr:(perm_eq_compose_rewrite_r (perm_eq_sym lem)) + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_perm_eq_compose_assoc_rewrite_r' (lem x) in + exact r)) + end. -Lemma stack_perms_compose {n0 n1} {f g} {f' g'} - (Hf' : permutation n0 f') (Hg' : permutation n1 g') : - (stack_perms n0 n1 f g ∘ stack_perms n0 n1 f' g' - = stack_perms n0 n1 (f ∘ f') (g ∘ g'))%prg. -Proof. - destruct Hf' as [Hf'inv Hf']. - destruct Hg' as [Hg'inv Hg']. - unfold compose. - (* bdestruct_one. *) - solve_modular_permutation_equalities. - 1,2: specialize (Hf' k H); lia. - - f_equal; f_equal. lia. - - assert (Hk: k - n0 < n1) by lia. - specialize (Hg' _ Hk); lia. -Qed. +Ltac rewrite_perm_eq_compose_assoc_r lem := + let lem' := make_perm_eq_compose_assoc_rewrite_r lem in + rewrite lem' || rewrite lem. -Lemma stack_perms_assoc {n0 n1 n2} {f g h} : - stack_perms (n0 + n1) n2 (stack_perms n0 n1 f g) h = - stack_perms n0 (n1 + n2) f (stack_perms n1 n2 g h). -Proof. - apply functional_extensionality; intros k. - unfold stack_perms. - bdestructΩ'. - rewrite (Nat.add_comm n0 n1), Nat.add_assoc. - f_equal; f_equal; f_equal. - lia. -Qed. +Ltac rewrite_perm_eq_compose_assoc_r' lem := + let lem' := make_perm_eq_compose_assoc_rewrite_r' lem in + rewrite lem' || rewrite <- lem. -Lemma stack_perms_idn_of_left_right_idn {n0 n1} {f g} - (Hf : forall k, k < n0 -> f k = k) (Hg : forall k, k < n1 -> g k = k) : - stack_perms n0 n1 f g = idn. -Proof. - solve_modular_permutation_equalities. - - apply Hf; easy. - - rewrite Hg; lia. -Qed. +Notation "'###perm_l' '->' lem" := + (ltac:(let r := make_perm_eq_compose_assoc_rewrite_l lem in exact r)) + (at level 0, lem at level 15, only parsing). -(* Section on rotr / rotl *) -Lemma rotr_WF : - forall n k, WF_Perm n (rotr n k). -Proof. unfold WF_Perm. intros. unfold rotr. bdestruct_one; lia. Qed. +Notation "'###perm_r' '->' lem" := + (ltac:(let r := make_perm_eq_compose_assoc_rewrite_r lem in exact r)) + (at level 0, lem at level 15, only parsing). -Lemma rotl_WF {n m} : - forall k, n <= k -> (rotl n m) k = k. -Proof. intros. unfold rotl. bdestruct_one; lia. Qed. +Notation "'###perm_l' '<-' lem" := + (ltac:(let r := make_perm_eq_compose_assoc_rewrite_l' lem in exact r)) + (at level 0, lem at level 15, only parsing). -#[export] Hint Resolve rotr_WF rotl_WF : WF_perm_db. +Notation "'###perm_r' '<-' lem" := + (ltac:(let r := make_perm_eq_compose_assoc_rewrite_r' lem in exact r)) + (at level 0, lem at level 15, only parsing). -Lemma rotr_bounded {n m} : - forall k, k < n -> (rotr n m) k < n. -Proof. - intros. unfold rotr. bdestruct_one; [lia|]. - apply Nat.mod_upper_bound; lia. -Qed. -Lemma rotl_bounded {n m} : - forall k, k < n -> (rotl n m) k < n. -Proof. - intros. unfold rotl. bdestruct_one; [lia|]. - apply Nat.mod_upper_bound; lia. -Qed. +Section ComposeLemmas. -#[export] Hint Resolve rotr_bounded rotl_bounded : perm_bounded_db. +Local Open Scope prg. -Lemma rotr_rotl_inv n m : - ((rotr n m) ∘ (rotl n m) = idn)%prg. +(* Helpers for rewriting with compose and perm_eq *) +Lemma compose_rewrite_l {f g h : nat -> nat} + (H : f ∘ g = h) : forall (i : nat -> nat), + i ∘ f ∘ g = i ∘ h. Proof. - apply functional_extensionality; intros k. - unfold compose, rotl, rotr. - bdestruct (n <=? k); [bdestructΩ'|]. - assert (Hn0 : n <> 0) by lia. - bdestruct_one. - - pose proof (Nat.mod_upper_bound (k + (n - m mod n)) n Hn0) as Hbad. - lia. (* contradict Hbad *) - - rewrite Nat.add_mod_idemp_l; [|easy]. - rewrite <- Nat.add_assoc. - replace (n - m mod n + m) with - (n - m mod n + (n * (m / n) + m mod n)) by - (rewrite <- (Nat.div_mod m n Hn0); easy). - pose proof (Nat.mod_upper_bound m n Hn0). - replace (n - m mod n + (n * (m / n) + m mod n)) with - (n * (1 + m / n)) by lia. - rewrite Nat.mul_comm, Nat.mod_add; [|easy]. - apply Nat.mod_small, H. + intros; + now rewrite compose_assoc, H. Qed. -Lemma rotl_rotr_inv n m : - ((rotl n m) ∘ (rotr n m) = idn)%prg. +Lemma compose_rewrite_l_to_2 {f g h i : nat -> nat} + (H : f ∘ g = h ∘ i) : forall (j : nat -> nat), + j ∘ f ∘ g = j ∘ h ∘ i. Proof. - apply functional_extensionality; intros k. - unfold compose, rotl, rotr. - bdestruct (n <=? k); [bdestructΩ'|]. - assert (Hn0 : n <> 0) by lia. - bdestruct_one. - - pose proof (Nat.mod_upper_bound (k + m) n Hn0) as Hbad. - lia. (* contradict Hbad *) - - rewrite Nat.add_mod_idemp_l; [|easy]. - rewrite <- Nat.add_assoc. - replace (m + (n - m mod n)) with - ((n * (m / n) + m mod n) + (n - m mod n)) by - (rewrite <- (Nat.div_mod m n Hn0); easy). - pose proof (Nat.mod_upper_bound m n Hn0). - replace ((n * (m / n) + m mod n) + (n - m mod n)) with - (n * (1 + m / n)) by lia. - rewrite Nat.mul_comm, Nat.mod_add; [|easy]. - apply Nat.mod_small, H. -Qed. - -#[export] Hint Rewrite rotr_rotl_inv rotl_rotr_inv : perm_inv_db. - -Lemma rotr_perm {n m} : permutation n (rotr n m). -Proof. - perm_by_inverse (rotl n m). + intros; + now rewrite !compose_assoc, H. Qed. -Lemma rotl_perm {n m} : permutation n (rotl n m). -Proof. - perm_by_inverse (rotr n m). -Qed. - -#[export] Hint Resolve rotr_perm rotl_perm : perm_db. - -Lemma rotr_0_r n : rotr n 0 = idn. +Lemma compose_rewrite_l_to_Id {f g : nat -> nat} + (H : f ∘ g = idn) : forall (h : nat -> nat), + h ∘ f ∘ g = h. Proof. - apply functional_extensionality; intros k. - unfold rotr. - bdestructΩ'. - rewrite Nat.mod_small; lia. + intros; + now rewrite compose_assoc, H, compose_idn_r. Qed. -Lemma rotl_0_r n : rotl n 0 = idn. +Lemma compose_rewrite_r {f g h : nat -> nat} + (H : f ∘ g = h) : forall (i : nat -> nat), + f ∘ (g ∘ i) = h ∘ i. Proof. - apply functional_extensionality; intros k. - unfold rotl. - bdestructΩ'. - rewrite Nat.mod_0_l, Nat.sub_0_r; [|lia]. - replace (k + n) with (k + 1 * n) by lia. - rewrite Nat.mod_add, Nat.mod_small; lia. + intros; + now rewrite <- compose_assoc, H. Qed. -Lemma rotr_0_l k : rotr 0 k = idn. +Lemma compose_rewrite_r_to_2 {f g h i : nat -> nat} + (H : f ∘ g = h ∘ i) : forall (j : nat -> nat), + f ∘ (g ∘ j) = h ∘ (i ∘ j). Proof. - apply functional_extensionality; intros a. - unfold rotr. - bdestructΩ'. + intros; + now rewrite <- !compose_assoc, H. Qed. - -Lemma rotl_0_l k : rotl 0 k = idn. -Proof. - apply functional_extensionality; intros a. - unfold rotl. - bdestructΩ'. -Qed. - -#[export] Hint Rewrite rotr_0_r rotl_0_r rotr_0_l rotl_0_l : perm_cleanup_db. -Lemma rotr_rotr n k l : - ((rotr n k) ∘ (rotr n l) = rotr n (k + l))%prg. +Lemma compose_rewrite_r_to_Id {f g : nat -> nat} + (H : f ∘ g = idn) : forall (h : nat -> nat), + f ∘ (g ∘ h) = h. Proof. - apply functional_extensionality; intros a. - unfold compose, rotr. - symmetry. - bdestructΩ'; assert (Hn0 : n <> 0) by lia. - - pose proof (Nat.mod_upper_bound (a + l) n Hn0); lia. - - rewrite Nat.add_mod_idemp_l; [|easy]. - f_equal; lia. + intros; + now rewrite <- compose_assoc, H, compose_idn_l. Qed. -Lemma rotl_rotl n k l : - ((rotl n k) ∘ (rotl n l) = rotl n (k + l))%prg. -Proof. - apply (WF_permutation_inverse_injective (rotr n (k + l)) n). - - apply rotr_perm. - - apply rotr_WF. - - rewrite Nat.add_comm, <- rotr_rotr, - <- Combinators.compose_assoc, (Combinators.compose_assoc _ _ _ _ (rotr n l)). - cleanup_perm; easy. (* rewrite rotl_rotr_inv, compose_idn_r, rotl_rotr_inv. *) - - rewrite rotl_rotr_inv; easy. -Qed. +End ComposeLemmas. -#[export] Hint Rewrite rotr_rotr rotl_rotl : perm_cleanup_db. +Ltac make_compose_assoc_rewrite_l lem := + lazymatch type of lem with + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_compose_assoc_rewrite_l (lem x) in + exact r)) + | (?F ∘ ?G)%prg = idn => + constr:(compose_rewrite_l_to_Id lem) + | (?F ∘ ?G)%prg = (?F' ∘ ?G')%prg => + constr:(compose_rewrite_l_to_2 lem) + | (?F ∘ ?G)%prg = ?H => + constr:(compose_rewrite_l lem) + end. -Lemma rotr_n n : rotr n n = idn. -Proof. - apply functional_extensionality; intros a. - unfold rotr. - bdestructΩ'. - replace (a + n) with (a + 1 * n) by lia. - destruct n; [lia|]. - rewrite Nat.mod_add; [|easy]. - rewrite Nat.mod_small; easy. -Qed. +Ltac make_compose_assoc_rewrite_l' lem := + lazymatch type of lem with + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_compose_assoc_rewrite_l' (lem x) in + exact r)) + | idn = (?F ∘ ?G)%prg => + constr:(compose_rewrite_l_to_Id (eq_sym lem)) + | (?F ∘ ?G)%prg = (?F' ∘ ?G')%prg => + constr:(compose_rewrite_l_to_2 (eq_sym lem)) + | ?H = (?F ∘ ?G)%prg => + constr:(compose_rewrite_l (eq_sym lem)) + end. -#[export] Hint Rewrite rotr_n : perm_cleanup_db. +Ltac rewrite_compose_assoc_l lem := + let lem' := make_compose_assoc_rewrite_l lem in + rewrite lem' || rewrite lem. + +Ltac rewrite_compose_assoc_l' lem := + let lem' := make_compose_assoc_rewrite_l' lem in + rewrite lem' || rewrite <- lem. + +Ltac make_compose_assoc_rewrite_r lem := + lazymatch type of lem with + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_compose_assoc_rewrite_r (lem x) in + exact r)) + | (?F ∘ ?G)%prg = idn => + constr:(compose_rewrite_r_to_Id lem) + | (?F ∘ ?G)%prg = (?F' ∘ ?G')%prg => + constr:(compose_rewrite_r_to_2 lem) + | (?F ∘ ?G)%prg = ?H => + constr:(compose_rewrite_r lem) + end. -Ltac perm_eq_by_WF_inv_inj f n := - let tryeasylia := try easy; try lia in - apply (WF_permutation_inverse_injective f n); [ - tryeasylia; auto with perm_db | - tryeasylia; auto with WF_perm_db | - try solve [cleanup_perm; auto] | - try solve [cleanup_perm; auto]]; tryeasylia. +Ltac make_compose_assoc_rewrite_r' lem := + lazymatch type of lem with + | forall a : ?A, @?f a => + let x := fresh a in + constr:(fun x : A => ltac:( + let r := make_compose_assoc_rewrite_r' (lem x) in + exact r)) + | idn = (?F ∘ ?G)%prg => + constr:(compose_rewrite_r_to_Id (eq_sym lem)) + | (?F ∘ ?G)%prg = (?F' ∘ ?G')%prg => + constr:(compose_rewrite_r_to_2 (eq_sym lem)) + | ?H = (?F ∘ ?G)%prg => + constr:(compose_rewrite_r (eq_sym lem)) + end. -Lemma rotr_eq_rotr_mod n k : rotr n k = rotr n (k mod n). -Proof. - perm_eq_by_WF_inv_inj (rotl n k) n. - - unfold WF_Perm. - apply rotl_WF. - - apply functional_extensionality. - intros a. - unfold Basics.compose, rotr. - bdestruct (n <=? a). - + rewrite (rotl_WF a) by easy. - bdestruct_all; easy. - + unfold rotl. - pose proof (Nat.mod_upper_bound (a + (n-k mod n)) n ltac:(lia)). - bdestruct_all; try lia. - rewrite <- Nat.add_mod by lia. - rewrite (Nat.div_mod k n) at 2 by lia. - replace ((a + (n - k mod n) + (n * (k / n) + k mod n))) - with ((a + ((n - k mod n) + k mod n + (n * (k / n))))) by lia. - rewrite Nat.sub_add by (pose proof (Nat.mod_upper_bound k n); lia). - rewrite Nat.add_assoc, Nat.mul_comm, Nat.mod_add by lia. - rewrite mod_add_n_r, Nat.mod_small; lia. -Qed. +Ltac rewrite_compose_assoc_r lem := + let lem' := make_compose_assoc_rewrite_r lem in + rewrite lem' || rewrite lem. -Lemma rotl_n n : rotl n n = idn. -Proof. - perm_eq_by_WF_inv_inj (rotr n n) n. -Qed. +Ltac rewrite_compose_assoc_r' lem := + let lem' := make_compose_assoc_rewrite_r' lem in + rewrite lem' || rewrite <- lem. -#[export] Hint Rewrite rotl_n : perm_cleanup_db. +Notation "'###comp_l' '->' lem" := + (ltac:(let r := make_compose_assoc_rewrite_l lem in exact r)) + (at level 0, lem at level 15, only parsing). -Lemma rotl_eq_rotl_mod n k : rotl n k = rotl n (k mod n). -Proof. - perm_eq_by_WF_inv_inj (rotr n k) n. - rewrite rotr_eq_rotr_mod, rotl_rotr_inv; easy. -Qed. +Notation "'###comp_r' '->' lem" := + (ltac:(let r := make_compose_assoc_rewrite_r lem in exact r)) + (at level 0, lem at level 15, only parsing). -Lemma rotr_eq_rotl_sub n k : - rotr n k = rotl n (n - k mod n). -Proof. - rewrite rotr_eq_rotr_mod. - perm_eq_by_WF_inv_inj (rotl n (k mod n)) n. - - unfold WF_Perm. - apply rotr_WF. - - cleanup_perm. - destruct n; [rewrite rotl_0_l; easy|]. - assert (H': S n <> 0) by easy. - pose proof (Nat.mod_upper_bound k _ H'). - rewrite <- (rotl_n (S n)). - f_equal. - lia. -Qed. +Notation "'###comp_l' '<-' lem" := + (ltac:(let r := make_compose_assoc_rewrite_l' lem in exact r)) + (at level 0, lem at level 15, only parsing). -Lemma rotl_eq_rotr_sub n k : - rotl n k = rotr n (n - k mod n). -Proof. - perm_eq_by_WF_inv_inj (rotr n k) n. - destruct n; [cbn; rewrite 2!rotr_0_l, compose_idn_l; easy|]. - rewrite (rotr_eq_rotr_mod _ k), rotr_rotr, <- (rotr_n (S n)). - f_equal. - assert (H' : S n <> 0) by easy. - pose proof (Nat.mod_upper_bound k (S n) H'). - lia. -Qed. +Notation "'###comp_r' '<-' lem" := + (ltac:(let r := make_compose_assoc_rewrite_r' lem in exact r)) + (at level 0, lem at level 15, only parsing). \ No newline at end of file diff --git a/PermutationInstances.v b/PermutationInstances.v new file mode 100644 index 0000000..811a97e --- /dev/null +++ b/PermutationInstances.v @@ -0,0 +1,4092 @@ +Require Import Modulus. +Require Export PermutationsBase. +Require Import PermutationAutomation. +Require Export Prelim. +Require Export Bits. + +Import Setoid. + +(* Definitions of particular permutations, operations on permutations, + and their interactions *) +Local Open Scope program_scope. +Local Open Scope nat_scope. + +Definition stack_perms (n0 n1 : nat) (f g : nat -> nat) : nat -> nat := + fun n => + if (n nat) : nat -> nat := + fun n => if (n0 * n1 <=? n) then n else + (f (n / n1) * n1 + g (n mod n1)). + +Definition swap_perm a b n := + fun k => if n <=? k then k else + if k =? a then b else + if k =? b then a else k. + + +(* TODO: Implement things for this *) +Fixpoint insertion_sort_list n f := + match n with + | 0 => [] + | S n' => let k := (perm_inv (S n') f n') in + k :: insertion_sort_list n' (Bits.fswap f k n') + end. + +Fixpoint swap_list_spec l : bool := + match l with + | [] => true + | k :: ks => (k idn + | k :: ks => let n := length ks in + (swap_perm k n (S n) ∘ (perm_of_swap_list ks))%prg + end. + +Fixpoint invperm_of_swap_list l := + match l with + | [] => idn + | k :: ks => let n := length ks in + ((invperm_of_swap_list ks) ∘ swap_perm k n (S n))%prg + end. + +Definition perm_inv' n f := + fun k => if n <=? k then k else perm_inv n f k. + +Definition contract_perm f a := + fun k => + if k nat := + swap_perm 0 1 2. + +Definition rotr n m : nat -> nat := + fun k => if n <=? k then k else (k + m) mod n. + +Definition rotl n m : nat -> nat := + fun k => if n <=? k then k else (k + (n - (m mod n))) mod n. + +Definition swap_block_perm padl padm a := + fun k => + if k + if k + if n <=? k then k else n - S k. + +(** Given a permutation p over n qubits, construct a permutation over 2^n indices. *) +Definition qubit_perm_to_nat_perm n (p : nat -> nat) := + fun k => + if 2 ^ n <=? k then k else + funbool_to_nat n ((nat_to_funbool n k) ∘ p)%prg. + +Definition kron_comm_perm p q := + fun k => if p * q <=? k then k else + k mod p * q + k / p. + +Definition perm_eq_id_mid (padl padm : nat) (f : nat -> nat) : Prop := + forall a, a < padm -> f (padl + a) = padl + a. + +Definition expand_perm_id_mid (padl padm padr : nat) + (f : nat -> nat) : nat -> nat := + stack_perms padl (padm + padr) idn (rotr (padm + padr) padm) + ∘ (stack_perms (padl + padr) padm f idn) + ∘ stack_perms padl (padm + padr) idn (rotr (padm + padr) padr). + +Definition contract_perm_id_mid (padl padm padr : nat) + (f : nat -> nat) : nat -> nat := + stack_perms padl (padm + padr) idn (rotr (padm + padr) padr) ∘ + f ∘ stack_perms padl (padm + padr) idn (rotr (padm + padr) padm). + +#[export] Hint Unfold + stack_perms compose + rotr rotl + swap_2_perm swap_perm : perm_unfold_db. + + + + +Lemma permutation_change_dims n m (H : n = m) f : + permutation n f <-> permutation m f. +Proof. + now subst. +Qed. + +Lemma perm_bounded_change_dims n m (Hnm : n = m) f (Hf : perm_bounded m f) : + perm_bounded n f. +Proof. + now subst. +Qed. + +Lemma perm_eq_dim_change_if_nonzero n m f g : + perm_eq m f g -> (n <> 0 -> n = m) -> perm_eq n f g. +Proof. + intros Hfg H k Hk. + rewrite H in Hk by lia. + now apply Hfg. +Qed. + +Lemma perm_eq_dim_change n m f g : + perm_eq m f g -> n = m -> perm_eq n f g. +Proof. + intros. + now apply (perm_eq_dim_change_if_nonzero n m f g). +Qed. + +Lemma permutation_defn n f : + permutation n f <-> exists g, + (perm_bounded n f) /\ (perm_bounded n g) /\ + (perm_eq n (f ∘ g) idn) /\ (perm_eq n (g ∘ f) idn). +Proof. + split; intros [g Hg]; exists g. + - repeat split; hnf; intros; now apply Hg. + - intros; repeat split; now apply Hg. +Qed. + +Lemma permutation_of_le_permutation_idn_above n m f : + permutation n f -> m <= n -> (forall k, m <= k < n -> f k = k) -> + permutation m f. +Proof. + intros Hf Hm Hfid. + pose proof Hf as Hf'. + destruct Hf' as [finv Hfinv]. + exists finv. + intros k Hk; repeat split; try (apply Hfinv; lia). + - pose proof (Hfinv k ltac:(lia)) as (?&?&?&?). + bdestructΩ (f k (f (finv k)) in Hfid. + lia. +Qed. + +Add Parametric Morphism n : (permutation n) + with signature perm_eq n ==> iff as permutation_perm_eq_proper. +Proof. + intros f g Hfg. + split; intros [inv Hinv]; + exists inv; + intros k Hk; + [rewrite <- 2!Hfg by (destruct (Hinv k Hk); easy) | + rewrite 2!Hfg by (destruct (Hinv k Hk); easy)]; + apply Hinv, Hk. +Qed. + +Lemma permutation_eqb_iff {n f} a b : permutation n f -> + a < n -> b < n -> + f a =? f b = (a =? b). +Proof. + intros Hperm Hk Hfk. + bdestruct_one. + apply (permutation_is_injective n f Hperm) in H; [bdestruct_one| |]; lia. + bdestruct_one; subst; easy. +Qed. + +Lemma permutation_eq_iff {n f} a b : permutation n f -> + a < n -> b < n -> + f a = f b <-> a = b. +Proof. + intros Hperm Hk Hfk. + generalize (permutation_eqb_iff _ _ Hperm Hk Hfk). + bdestructΩ'. +Qed. + +Lemma perm_eq_iff_forall n (f g : nat -> nat) : + perm_eq n f g <-> forallb (fun k => f k =? g k) (seq 0 n) = true. +Proof. + rewrite forallb_seq0. + now setoid_rewrite Nat.eqb_eq. +Qed. + +Lemma perm_eq_dec n (f g : nat -> nat) : + {perm_eq n f g} + {~ perm_eq n f g}. +Proof. + generalize (perm_eq_iff_forall n f g). + destruct (forallb (fun k => f k =? g k) (seq 0 n)); intros H; + [left | right]; rewrite H; easy. +Qed. + +Lemma not_forallb_seq_exists f start len : + forallb f (seq start len) = false -> + exists n, n < len /\ f (n + start) = false. +Proof. + revert start; induction len; [easy|]. + intros start. + simpl. + rewrite andb_false_iff. + intros [H | H]. + - exists 0. split; [lia | easy]. + - destruct (IHlen (S start) H) as (n & Hn & Hfn). + exists (S n); split; rewrite <- ?Hfn; f_equal; lia. +Qed. + +Lemma not_forallb_seq0_exists f n : + forallb f (seq 0 n) = false -> + exists k, k < n /\ f k = false. +Proof. + intros H. + apply not_forallb_seq_exists in H. + setoid_rewrite Nat.add_0_r in H. + exact H. +Qed. + +Lemma not_perm_eq_not_eq_at n (f g : nat -> nat) : + ~ (perm_eq n f g) -> exists k, k < n /\ f k <> g k. +Proof. + rewrite perm_eq_iff_forall. + rewrite not_true_iff_false. + intros H. + apply not_forallb_seq0_exists in H. + setoid_rewrite Nat.eqb_neq in H. + exact H. +Qed. + +Lemma perm_bounded_of_eq {n f g} : + perm_eq n g f -> perm_bounded n f -> + perm_bounded n g. +Proof. + intros Hfg Hf k Hk. + rewrite Hfg; auto. +Qed. + +Lemma compose_perm_bounded n f g : perm_bounded n f -> perm_bounded n g -> + perm_bounded n (f ∘ g). +Proof. + unfold compose. + auto. +Qed. + +#[export] Hint Resolve compose_perm_bounded : perm_bounded_db. + +(* Section on perm_inv *) +Lemma perm_inv_linv_of_permutation n f (Hf : permutation n f) : + perm_eq n (perm_inv n f ∘ f) idn. +Proof. + exact (perm_inv_is_linv_of_permutation n f Hf). +Qed. + +Lemma perm_inv_rinv_of_permutation n f (Hf : permutation n f) : + perm_eq n (f ∘ perm_inv n f) idn. +Proof. + exact (perm_inv_is_rinv_of_permutation n f Hf). +Qed. + +#[export] Hint Rewrite + perm_inv_linv_of_permutation + perm_inv_rinv_of_permutation + using (solve [auto with perm_db]) : perm_inv_db. + +Lemma perm_inv'_eq n f : + perm_eq n (perm_inv' n f) (perm_inv n f). +Proof. + intros k Hk. + unfold perm_inv'. + bdestructΩ'. +Qed. + +#[export] Hint Extern 0 + (perm_eq ?n (perm_inv' ?n ?f) ?g) => + apply (perm_eq_trans (perm_inv'_eq n _)) : perm_inv_db. + +#[export] Hint Extern 0 + (perm_eq ?n ?g (perm_inv' ?n ?f)) => + apply (fun H => perm_eq_trans + H (perm_eq_sym (perm_inv'_eq n _))) : perm_inv_db. + +#[export] Hint Rewrite perm_inv'_eq : perm_inv_db. + +Lemma perm_inv'_bounded n f : + perm_bounded n (perm_inv' n f). +Proof. + apply (perm_bounded_of_eq (perm_inv'_eq n f)). + auto with perm_bounded_db. +Qed. + +Lemma perm_inv'_WF n f : + WF_Perm n (perm_inv' n f). +Proof. + intros k Hk; + unfold perm_inv'; + bdestructΩ'. +Qed. + +#[export] Hint Resolve perm_inv'_bounded : perm_bounded_db. +#[export] Hint Resolve perm_inv'_WF : WF_Perm_db. + +Lemma perm_inv'_permutation n f : permutation n f -> + permutation n (perm_inv' n f). +Proof. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv'_permutation : perm_db. + +Lemma permutation_of_le_permutation_WF f m n : (m <= n)%nat -> permutation m f -> + WF_Perm m f -> permutation n f. +Proof. + intros Hmn [finv_m Hfinv_m] HWF. + exists (fun k => if m <=? k then k else finv_m k). + intros k Hk. + bdestruct (m <=? k). + - rewrite !HWF; bdestructΩ'. + - specialize (Hfinv_m _ H). + bdestructΩ'. +Qed. + +Lemma perm_eq_compose_proper n (f f' g g' : nat -> nat) : + perm_bounded n g -> perm_eq n f f' -> perm_eq n g g' -> + perm_eq n (f ∘ g) (f' ∘ g'). +Proof. + intros Hg Hf' Hg' k Hk. + unfold compose. + now rewrite Hf', Hg' by auto. +Qed. + +#[export] Hint Resolve perm_eq_compose_proper : perm_inv_db. + +Add Parametric Morphism n f : (@compose nat nat nat f) with signature + perm_eq n ==> perm_eq n as compose_perm_eq_proper_r. +Proof. + intros g g' Hg k Hk. + unfold compose. + now rewrite Hg. +Qed. + +Add Parametric Morphism n : (@compose nat nat nat) with signature + perm_eq n ==> + on_predicate_relation_l (fun f => perm_bounded n f) (perm_eq n) ==> + perm_eq n as compose_perm_eq_proper_l. +Proof. + intros f f' Hf g g' [Hgbdd Hg] k Hk. + unfold compose. + rewrite <- Hg by easy. + auto. +Qed. + +Lemma perm_inv_is_linv_of_permutation_compose (n : nat) (f : nat -> nat) : + permutation n f -> + perm_eq n (perm_inv n f ∘ f) idn. +Proof. + exact (perm_inv_is_linv_of_permutation n f). +Qed. + +#[export] Hint Resolve + perm_inv_is_linv_of_permutation + perm_inv_is_linv_of_permutation_compose : perm_inv_db. + +Lemma perm_inv_is_rinv_of_permutation_compose (n : nat) (f : nat -> nat) : + permutation n f -> + perm_eq n (f ∘ perm_inv n f) idn. +Proof. + exact (perm_inv_is_rinv_of_permutation n f). +Qed. + +#[export] Hint Resolve + perm_inv_is_rinv_of_permutation + perm_inv_is_rinv_of_permutation_compose : perm_inv_db. + +#[export] Hint Rewrite perm_inv_is_linv_of_permutation_compose + perm_inv_is_rinv_of_permutation_compose + using solve [auto with perm_db] : perm_inv_db. + +#[export] Hint Rewrite + ###perm_l -> perm_inv_is_linv_of_permutation_compose + using solve [auto with perm_bounded_db perm_db] : perm_inv_db. + +#[export] Hint Rewrite + ###perm_r -> perm_inv_is_linv_of_permutation_compose + using solve [auto with perm_bounded_db perm_db] : perm_inv_db. + +#[export] Hint Rewrite + ###perm_l -> perm_inv_is_rinv_of_permutation_compose + using solve [auto with perm_bounded_db perm_db] : perm_inv_db. + +#[export] Hint Rewrite + ###perm_r -> perm_inv_is_rinv_of_permutation_compose + using solve [auto with perm_bounded_db perm_db] : perm_inv_db. + +Lemma perm_inv'_is_linv_of_permutation_compose (n : nat) (f : nat -> nat) : + permutation n f -> + perm_eq n (perm_inv' n f ∘ f) idn. +Proof. + intros. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv'_is_linv_of_permutation_compose : perm_inv_db. + +Lemma perm_inv'_is_rinv_of_permutation_compose (n : nat) (f : nat -> nat) : + permutation n f -> + perm_eq n (f ∘ perm_inv' n f) idn. +Proof. + intros. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv'_is_rinv_of_permutation_compose : perm_inv_db. + +Lemma permutation_iff_perm_inv'_inv n f : + permutation n f <-> + perm_eq n (f ∘ perm_inv' n f) idn /\ + perm_eq n (perm_inv' n f ∘ f) idn. +Proof. + split; [auto_perm|]. + intros [Hrinv Hlinv]. + assert (Hfbdd : perm_bounded n f). { + intros k Hk. + generalize (Hlinv k Hk). + unfold compose, perm_inv'. + bdestructΩ'. + } + exists (perm_inv' n f). + intros k Hk. + repeat split; [auto_perm.. | |]. + - now apply Hlinv. + - now apply Hrinv. +Qed. + + +Lemma idn_WF_Perm n : WF_Perm n idn. +Proof. easy. Qed. + +#[export] Hint Resolve idn_WF_Perm : WF_Perm_db. + + +Lemma perm_inv'_linv_of_permutation_WF n f : + permutation n f -> WF_Perm n f -> + perm_inv' n f ∘ f = idn. +Proof. + intros. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +Lemma perm_inv'_rinv_of_permutation_WF n f : + permutation n f -> WF_Perm n f -> + f ∘ perm_inv' n f = idn. +Proof. + intros. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite perm_inv'_linv_of_permutation_WF + perm_inv'_rinv_of_permutation_WF + using (solve [auto with perm_db WF_Perm_db]) : perm_inv_db. + +#[export] Hint Rewrite + (###comp_l -> perm_inv'_linv_of_permutation_WF) + (###comp_r -> perm_inv'_linv_of_permutation_WF) + (###comp_l -> perm_inv'_rinv_of_permutation_WF) + (###comp_r -> perm_inv'_rinv_of_permutation_WF) + using (solve [auto with perm_db WF_Perm_db]) : perm_inv_db. + + +Lemma perm_eq_linv_injective n f finv finv' : permutation n f -> + is_perm_linv n f finv -> is_perm_linv n f finv' -> + perm_eq n finv finv'. +Proof. + intros Hperm Hfinv Hfinv'. + perm_eq_by_inv_inj f n. +Qed. + +Lemma perm_inv_eq_inv n f finv : + (forall x : nat, x < n -> f x < n /\ finv x < n + /\ finv (f x) = x /\ f (finv x) = x) + -> perm_eq n (perm_inv n f) finv. +Proof. + intros Hfinv. + assert (Hperm: permutation n f) by (exists finv; easy). + perm_eq_by_inv_inj f n. + intros k Hk; now apply Hfinv. +Qed. + +Lemma perm_inv_is_inv n f : permutation n f -> + forall k : nat, k < n -> perm_inv n f k < n /\ f k < n + /\ f (perm_inv n f k) = k /\ perm_inv n f (f k) = k. +Proof. + intros Hperm k Hk. + repeat split. + - apply perm_inv_bounded, Hk. + - destruct Hperm as [? H]; apply H, Hk. + - rewrite perm_inv_is_rinv_of_permutation; easy. + - rewrite perm_inv_is_linv_of_permutation; easy. +Qed. + +Lemma perm_inv_perm_inv n f : permutation n f -> + perm_eq n (perm_inv n (perm_inv n f)) f. +Proof. + intros Hf. + perm_eq_by_inv_inj (perm_inv n f) n. +Qed. + +#[export] Hint Resolve perm_inv_perm_inv : perm_inv_db. +#[export] Hint Rewrite perm_inv_perm_inv + using solve [auto with perm_db] : perm_inv_db. + +Lemma perm_inv_eq_of_perm_eq' n m f g : perm_eq n f g -> m <= n -> + perm_eq n (perm_inv m f) (perm_inv m g). +Proof. + intros Heq Hm. + induction m; [easy|]. + intros k Hk. + simpl. + rewrite Heq by lia. + rewrite IHm by lia. + easy. +Qed. + +Lemma perm_inv_eq_of_perm_eq n f g : perm_eq n f g -> + perm_eq n (perm_inv n f) (perm_inv n g). +Proof. + intros Heq. + apply perm_inv_eq_of_perm_eq'; easy. +Qed. + +#[export] Hint Resolve perm_inv_eq_of_perm_eq : perm_inv_db. + +Lemma perm_inv'_eq_of_perm_eq n f g : perm_eq n f g -> + perm_inv' n f = perm_inv' n g. +Proof. + intros Heq. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv_eq_of_perm_eq' : perm_inv_db. + +Add Parametric Morphism n : (perm_inv n) with signature + perm_eq n ==> perm_eq n as perm_inv_perm_eq_proper. +Proof. + apply perm_inv_eq_of_perm_eq. +Qed. + +Add Parametric Morphism n : (perm_inv' n) with signature + perm_eq n ==> eq as perm_inv'_perm_eq_to_eq_proper. +Proof. + apply perm_inv'_eq_of_perm_eq. +Qed. + +Add Parametric Morphism n : (perm_inv' n) with signature + perm_eq n ==> perm_eq n as perm_inv'_perm_eq_proper. +Proof. + now intros f g ->. +Qed. + +#[export] Hint Extern 20 + (?f = ?g) => + eapply eq_of_WF_perm_eq; + [solve [auto with WF_Perm_db]..|] : perm_inv_db. + +#[export] Hint Extern 20 + (?f ?k = ?g ?k) => + match goal with + | Hk : k < ?n |- _ => + let Heq := fresh in + enough (Heq : perm_eq n f g) by (exact (Heq k Hk)) + end : perm_inv_db. + +Lemma perm_inv'_perm_inv n f : permutation n f -> + perm_eq n (perm_inv' n (perm_inv n f)) f. +Proof. + cleanup_perm_inv. +Qed. + +Lemma perm_inv_perm_inv' n f : permutation n f -> + perm_eq n (perm_inv n (perm_inv' n f)) f. +Proof. + intros Hf k Hk. + rewrite (perm_inv_eq_of_perm_eq _ _ _ (perm_inv'_eq _ _)) by easy. + cleanup_perm_inv. +Qed. + +Lemma perm_inv'_perm_inv_eq n f : + permutation n f -> WF_Perm n f -> + perm_inv' n (perm_inv n f) = f. +Proof. + intros. + cleanup_perm_inv. +Qed. + +Lemma perm_inv'_perm_inv' n f : permutation n f -> + perm_eq n (perm_inv' n (perm_inv' n f)) f. +Proof. + intros Hf. + rewrite (perm_inv'_eq_of_perm_eq _ _ _ (perm_inv'_eq n f)). + cleanup_perm_inv. +Qed. + +Lemma perm_inv'_perm_inv'_eq n f : + permutation n f -> WF_Perm n f -> + perm_inv' n (perm_inv' n f) = f. +Proof. + rewrite (perm_inv'_eq_of_perm_eq _ _ _ (perm_inv'_eq n f)). + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv'_perm_inv + perm_inv'_perm_inv' perm_inv_perm_inv' : perm_inv_db. +#[export] Hint Rewrite perm_inv'_perm_inv_eq + perm_inv'_perm_inv'_eq + using + solve [auto with perm_db WF_Perm_db] : perm_inv_db. + +Lemma permutation_compose' n f g : + permutation n f -> permutation n g -> + permutation n (fun x => f (g x)). +Proof. + apply permutation_compose. +Qed. + +#[export] Hint Resolve permutation_compose permutation_compose' : perm_db. + +#[export] Hint Rewrite perm_inv_is_linv_of_permutation + perm_inv_is_rinv_of_permutation : perm_inv_db. + +Lemma perm_inv_eq_iff {n g} (Hg : permutation n g) + {k m} (Hk : k < n) (Hm : m < n) : + perm_inv n g k = m <-> k = g m. +Proof. + split; + [intros <- | intros ->]; + rewrite ?(perm_inv_is_rinv_of_permutation _ g Hg), + ?(perm_inv_is_linv_of_permutation _ g Hg); + easy. +Qed. + +Lemma perm_inv_eqb_iff {n g} (Hg : permutation n g) + {k m} (Hk : k < n) (Hm : m < n) : + (perm_inv n g k =? m) = (k =? g m). +Proof. + apply Bool.eq_iff_eq_true; + rewrite 2!Nat.eqb_eq; + now apply perm_inv_eq_iff. +Qed. + +Lemma perm_inv_ge n g k : + n <= perm_inv n g k -> n <= k. +Proof. + intros H. + bdestruct (n <=? k); [lia|]. + specialize (perm_inv_bounded n g k); lia. +Qed. + +Lemma compose_perm_inv_l n f g h + (Hf : permutation n f) (Hg : perm_bounded n g) + (Hh : perm_bounded n h) : + perm_eq n (perm_inv n f ∘ g) h <-> + perm_eq n g (f ∘ h). +Proof. + split; unfold compose. + - intros H k Hk. + rewrite <- H; cleanup_perm_inv. + - intros H k Hk. + rewrite H; cleanup_perm_inv. +Qed. + +Lemma compose_perm_inv_r n f g h + (Hf : permutation n f) (Hg : perm_bounded n g) + (Hh : perm_bounded n h) : + perm_eq n (g ∘ perm_inv n f) h <-> + perm_eq n g (h ∘ f). +Proof. + split; unfold compose. + - intros H k Hk. + rewrite <- H; cleanup_perm_inv. + - intros H k Hk. + rewrite H; cleanup_perm_inv. +Qed. + +Lemma compose_perm_inv_l' n f g h + (Hf : permutation n f) (Hg : perm_bounded n g) + (Hh : perm_bounded n h) : + perm_eq n h (perm_inv n f ∘ g) <-> + perm_eq n (f ∘ h) g. +Proof. + split; intros H; + apply perm_eq_sym, + compose_perm_inv_l, perm_eq_sym; + assumption. +Qed. + +Lemma compose_perm_inv_r' n f g h + (Hf : permutation n f) (Hg : perm_bounded n g) + (Hh : perm_bounded n h) : + perm_eq n h (g ∘ perm_inv n f) <-> + perm_eq n (h ∘ f) g. +Proof. + split; intros H; + apply perm_eq_sym, + compose_perm_inv_r, perm_eq_sym; + assumption. +Qed. + +Lemma compose_perm_inv'_l n (f g h : nat -> nat) + (Hf : permutation n f) (HWFf : WF_Perm n f) : + perm_inv' n f ∘ g = h <-> g = f ∘ h. +Proof. + split; [intros <- | intros ->]; + rewrite <- compose_assoc; + cleanup_perm_inv. +Qed. + +Lemma compose_perm_inv'_r n (f g h : nat -> nat) + (Hf : permutation n f) (HWFf : WF_Perm n f) : + g ∘ perm_inv' n f = h <-> g = h ∘ f. +Proof. + split; [intros <- | intros ->]; + rewrite compose_assoc; + cleanup_perm_inv. +Qed. + +Lemma compose_perm_inv'_l' n (f g h : nat -> nat) + (Hf : permutation n f) (HWFf : WF_Perm n f) : + h = perm_inv' n f ∘ g <-> f ∘ h = g. +Proof. + split; [intros -> | intros <-]; + rewrite <- compose_assoc; + cleanup_perm_inv. +Qed. + +Lemma compose_perm_inv'_r' n (f g h : nat -> nat) + (Hf : permutation n f) (HWFf : WF_Perm n f) : + h = g ∘ perm_inv' n f <-> h ∘ f = g. +Proof. + split; [intros -> | intros <-]; + rewrite compose_assoc; + cleanup_perm_inv. +Qed. + +Lemma perm_inv_perm_eq_iff n f g + (Hf : permutation n f) (Hg : permutation n g) : + perm_eq n (perm_inv n g) f <-> perm_eq n g (perm_inv n f). +Proof. + split; [intros <- | intros ->]; + cleanup_perm_inv. +Qed. + +Lemma perm_inv_compose {n f g} (Hf : permutation n f) (Hg : permutation n g) : + perm_eq n + (perm_inv n (f ∘ g)) + (perm_inv n g ∘ perm_inv n f). +Proof. + apply perm_eq_sym. + perm_eq_by_inv_inj (f ∘ g) n. + rewrite !compose_assoc. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv_compose : perm_inv_db. +#[export] Hint Rewrite @perm_inv_compose + using solve [auto with perm_db] : perm_inv_db. + +Lemma perm_inv_compose_alt n f g + (Hf : permutation n f) (Hg : permutation n g) : + perm_eq n + (perm_inv n (fun x => f (g x))) + (fun x => perm_inv n g (perm_inv n f x))%prg. +Proof. + now apply perm_inv_compose. +Qed. + +Lemma perm_inv'_compose {n f g} + (Hf : permutation n f) (Hg : permutation n g) : + perm_inv' n (f ∘ g) = + perm_inv' n g ∘ perm_inv' n f. +Proof. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite @perm_inv'_compose + using solve [auto with perm_db] : perm_inv_db. + +Lemma perm_inv_inj n f g : + permutation n f -> permutation n g -> + perm_eq n (perm_inv n f) (perm_inv n g) -> + perm_eq n f g. +Proof. + intros Hf Hg Hfg. + rewrite <- (perm_inv_perm_inv n f Hf). + rewrite Hfg. + rewrite perm_inv_perm_inv by easy. + easy. +Qed. + +(* Permute bounded predicates *) +Lemma forall_lt_iff n (P Q : nat -> Prop) + (HPQ : forall k, k < n -> P k <-> Q k) : + (forall k, k < n -> P k) <-> (forall k, k < n -> Q k). +Proof. + apply forall_iff; intros k. + apply impl_iff; intros Hk. + auto. +Qed. + +Lemma forall_lt_iff_permute n f (Hf : permutation n f) + (P : nat -> Prop) : + (forall k, k < n -> P k) <-> (forall k, k < n -> P (f k)). +Proof. + split; intros HP. + - intros k Hk. + apply HP. + auto with perm_db. + - intros k Hk. + generalize (HP (perm_inv n f k) (perm_inv_bounded n f k Hk)). + now rewrite perm_inv_is_rinv_of_permutation by easy. +Qed. + +Lemma forall_lt_iff_of_permute_l n f (Hf : permutation n f) + (P Q : nat -> Prop) (HPQ : forall k, k < n -> P (f k) <-> Q k) : + (forall k, k < n -> P k) <-> (forall k, k < n -> Q k). +Proof. + rewrite (forall_lt_iff_permute n f Hf). + apply forall_iff; intros k. + apply impl_iff; intros Hk. + now apply HPQ. +Qed. + +Lemma forall_lt_iff_of_permute_r n f (Hf : permutation n f) + (P Q : nat -> Prop) (HPQ : forall k, k < n -> P k <-> Q (f k)) : + (forall k, k < n -> P k) <-> (forall k, k < n -> Q k). +Proof. + symmetry. + apply (forall_lt_iff_of_permute_l n f Hf). + intros k Hk. + now rewrite HPQ. +Qed. + + +Lemma idn_inv n : + perm_eq n (perm_inv n idn) idn. +Proof. + perm_eq_by_inv_inj (fun k:nat => k) n. +Qed. + +#[export] Hint Resolve idn_inv : perm_inv_db. + +Lemma idn_inv' n : + perm_inv' n idn = idn. +Proof. + permutation_eq_by_WF_inv_inj (fun k:nat=>k) n. +Qed. + +#[export] Hint Rewrite idn_inv' : perm_inv_db. + +Lemma swap_perm_defn a b n : a < n -> b < n -> + perm_eq n (swap_perm a b n) + (fun x => + if x =? a then b else + if x =? b then a else x). +Proof. + intros Ha Hb k Hk. + unfold swap_perm. + bdestructΩ'. +Qed. + +Lemma swap_perm_same a n : + swap_perm a a n = idn. +Proof. + unfold swap_perm. + apply functional_extensionality; intros k. + bdestructΩ'. +Qed. + +Lemma swap_perm_left a b n : a < n -> + swap_perm a b n a = b. +Proof. + unfold swap_perm; bdestructΩ'. +Qed. + +Lemma swap_perm_right a b n : b < n -> + swap_perm a b n b = a. +Proof. + unfold swap_perm; bdestructΩ'. +Qed. + +Lemma swap_perm_neither a b n x : x <> a -> x <> b -> + swap_perm a b n x = x. +Proof. + unfold swap_perm; bdestructΩ'. +Qed. + +Lemma swap_perm_comm a b n : + swap_perm a b n = swap_perm b a n. +Proof. + apply functional_extensionality; intros k. + unfold swap_perm. + bdestructΩ'. +Qed. + +Lemma swap_perm_WF a b n : + WF_Perm n (swap_perm a b n). +Proof. + intros k Hk. + unfold swap_perm. + bdestructΩ'. +Qed. + +Lemma swap_perm_bounded a b n : a < n -> b < n -> + perm_bounded n (swap_perm a b n). +Proof. + intros Ha Hb k Hk. + unfold swap_perm. + bdestructΩ'. +Qed. + +Lemma swap_perm_invol a b n : a < n -> b < n -> + (swap_perm a b n) ∘ (swap_perm a b n) = idn. +Proof. + intros Ha Hb. + unfold compose. + apply functional_extensionality; intros k. + unfold swap_perm. + bdestructΩ'. +Qed. + +#[export] Hint Rewrite swap_perm_same : perm_cleanup_db. +#[export] Hint Resolve swap_perm_WF : WF_Perm_db. +#[export] Hint Resolve swap_perm_bounded : perm_bounded_db. +#[export] Hint Rewrite swap_perm_invol using lia : perm_inv_db. + +Lemma swap_perm_big a b n : n <= a -> n <= b -> + perm_eq n (swap_perm a b n) idn. +Proof. + intros Ha Hb k Hk. + unfold swap_perm. + bdestructΩ'. +Qed. + +#[export] Hint Rewrite swap_perm_big using lia : perm_cleanup_db. + +Lemma swap_perm_big_eq a b n : + n <= a -> n <= b -> + swap_perm a b n = idn. +Proof. + intros. + eq_by_WF_perm_eq n. + cleanup_perm. +Qed. + +Lemma swap_perm_permutation a b n : a < n -> b < n -> + permutation n (swap_perm a b n). +Proof. + intros Ha Hb. + perm_by_inverse (swap_perm a b n). +Qed. + +Lemma swap_perm_S_permutation a n (Ha : S a < n) : + permutation n (swap_perm a (S a) n). +Proof. + apply swap_perm_permutation; lia. +Qed. + +#[export] Hint Resolve swap_perm_permutation : perm_db. +#[export] Hint Resolve swap_perm_S_permutation : perm_db. + +Lemma swap_perm_permutation_alt a b n : + n <= a -> n <= b -> + permutation n (swap_perm a b n). +Proof. + intros Ha Hb. + cleanup_perm. +Qed. + + +Lemma swap_perm_inv a b n : a < n -> b < n -> + perm_eq n (perm_inv n (swap_perm a b n)) + (swap_perm a b n). +Proof. + intros Ha Hb. + perm_eq_by_inv_inj (swap_perm a b n) n. +Qed. + +#[export] Hint Resolve swap_perm_inv : perm_inv_db. +#[export] Hint Rewrite swap_perm_inv using lia : perm_inv_db. + +Lemma swap_perm_inv' a b n : a < n -> b < n -> + perm_inv' n (swap_perm a b n) = + swap_perm a b n. +Proof. + intros. + eq_by_WF_perm_eq n; cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite swap_perm_inv' using lia : perm_inv_db. + +Lemma swap_perm_even_S_even_permutation a n : + Nat.even a = true -> Nat.even n = true -> + permutation n (swap_perm a (S a) n). +Proof. + intros Ha Hn. + bdestruct (a + replace n with n' by lia; + apply swap_perm_bounded; + lia + auto with zarith : perm_bounded_db. *) + + + +Lemma compose_swap_perm a b c n : a < n -> b < n -> c < n -> + b <> c -> a <> c -> + (swap_perm a b n ∘ swap_perm b c n ∘ swap_perm a b n) = swap_perm a c n. +Proof. + intros Ha Hb Hc Hbc Hac. + eq_by_WF_perm_eq n. + rewrite <- swap_perm_inv at 1 by easy. + rewrite compose_assoc. + apply compose_perm_inv_l; [cleanup_perm_inv..|]. + rewrite !swap_perm_defn by easy. + unfold compose. + intros k Hk. + bdestructΩ'. +Qed. + +#[export] Hint Rewrite compose_swap_perm using lia : perm_cleanup_db. + + + + +(* Section on insertion_sort_list *) + +Lemma fswap_eq_compose_swap_perm {A} (f : nat -> A) n m o : n < o -> m < o -> + fswap f n m = f ∘ swap_perm n m o. +Proof. + intros Hn Hm. + apply functional_extensionality; intros k. + unfold compose, fswap, swap_perm. + bdestruct_all; easy. +Qed. + +Lemma fswap_perm_invol_n_permutation f n : permutation (S n) f -> + permutation n (fswap f (perm_inv (S n) f n) n). +Proof. + intros Hperm. + apply fswap_at_boundary_permutation. + - apply Hperm. + - apply perm_inv_bounded_S. + - apply perm_inv_is_rinv_of_permutation; auto. +Qed. + +Lemma perm_of_swap_list_WF l : swap_list_spec l = true -> + WF_Perm (length l) (perm_of_swap_list l). +Proof. + induction l. + - easy. + - simpl. + rewrite andb_true_iff. + intros [Ha Hl]. + intros k Hk. + unfold compose. + rewrite IHl; [|easy|lia]. + rewrite swap_perm_WF; easy. +Qed. + +Lemma invperm_of_swap_list_WF l : swap_list_spec l = true -> + WF_Perm (length l) (invperm_of_swap_list l). +Proof. + induction l. + - easy. + - simpl. + rewrite andb_true_iff. + intros [Ha Hl]. + intros k Hk. + unfold compose. + rewrite swap_perm_WF; [|easy]. + rewrite IHl; [easy|easy|lia]. +Qed. + +#[export] Hint Resolve perm_of_swap_list_WF invperm_of_swap_list_WF : WF_Perm_db. + +Lemma perm_of_swap_list_bounded l : swap_list_spec l = true -> + perm_bounded (length l) (perm_of_swap_list l). +Proof. + induction l; [easy|]. + simpl. + rewrite andb_true_iff. + intros [Ha Hl]. + intros k Hk. + unfold compose. + rewrite Nat.ltb_lt in Ha. + apply swap_perm_bounded; try lia. + bdestruct (k =? length l). + - subst; rewrite perm_of_swap_list_WF; try easy; lia. + - transitivity (length l); [|lia]. + apply IHl; [easy | lia]. +Qed. + +Lemma invperm_of_swap_list_bounded l : swap_list_spec l = true -> + perm_bounded (length l) (invperm_of_swap_list l). +Proof. + induction l; [easy|]. + simpl. + rewrite andb_true_iff. + intros [Ha Hl]. + rewrite Nat.ltb_lt in Ha. + intros k Hk. + unfold compose. + bdestruct (swap_perm a (length l) (S (length l)) k =? length l). + - rewrite H, invperm_of_swap_list_WF; [lia|easy|easy]. + - transitivity (length l); [|lia]. + apply IHl; [easy|]. + pose proof (swap_perm_bounded a (length l) (S (length l)) Ha (ltac:(lia)) k Hk). + lia. +Qed. + +#[export] Hint Resolve perm_of_swap_list_bounded + invperm_of_swap_list_bounded : perm_bounded_db. + + +Lemma invperm_linv_perm_of_swap_list l : swap_list_spec l = true -> + invperm_of_swap_list l ∘ perm_of_swap_list l = idn. +Proof. + induction l. + - easy. + - simpl. + rewrite andb_true_iff. + intros [Ha Hl]. + rewrite Combinators.compose_assoc, + <- (Combinators.compose_assoc _ _ _ _ (perm_of_swap_list _)). + rewrite swap_perm_invol, compose_idn_l. + + apply (IHl Hl). + + bdestructΩ (a + perm_of_swap_list l ∘ invperm_of_swap_list l = idn. +Proof. + induction l. + - easy. + - simpl. + rewrite andb_true_iff. + intros [Ha Hl]. + rewrite <- Combinators.compose_assoc, + (Combinators.compose_assoc _ _ _ _ (invperm_of_swap_list _)). + rewrite (IHl Hl). + rewrite compose_idn_r. + rewrite swap_perm_invol; [easy| |lia]. + bdestructΩ (a + perm_eq n (invperm_of_swap_list (insertion_sort_list n f) + ∘ perm_of_swap_list (insertion_sort_list n f)) idn. +Proof. + cleanup_perm_inv. +Qed. + +Lemma invperm_rinv_perm_of_insertion_sort_list n f : permutation n f -> + perm_eq n (perm_of_swap_list (insertion_sort_list n f) + ∘ invperm_of_swap_list (insertion_sort_list n f)) idn. +Proof. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve invperm_linv_perm_of_insertion_sort_list + invperm_rinv_perm_of_insertion_sort_list : perm_inv_db. + + +Lemma perm_of_insertion_sort_list_is_rinv n f : permutation n f -> + perm_eq n (f ∘ perm_of_swap_list (insertion_sort_list n f)) idn. +Proof. + revert f; + induction n; + intros f. + - easy. + - intros Hperm k Hk. + cbn -[perm_inv]. + rewrite length_insertion_sort_list. + bdestruct (k =? n). + + unfold compose. + rewrite perm_of_swap_list_WF; [ | + apply insertion_sort_list_is_swap_list | + rewrite length_insertion_sort_list; lia + ]. + unfold swap_perm. + bdestructΩ'; + [replace -> n at 1|]; + cleanup_perm. + + rewrite <- compose_assoc. + rewrite <- fswap_eq_compose_swap_perm by + auto with perm_bounded_db. + rewrite IHn; [easy| |lia]. + apply fswap_perm_invol_n_permutation, Hperm. +Qed. + +#[export] Hint Resolve perm_of_insertion_sort_list_is_rinv : perm_inv_db. +#[export] Hint Rewrite perm_of_insertion_sort_list_is_rinv + using solve [auto with perm_db] : perm_inv_db. + +Lemma perm_of_insertion_sort_list_WF n f : + WF_Perm n (perm_of_swap_list (insertion_sort_list n f)). +Proof. + rewrite <- (length_insertion_sort_list n f) at 1. + auto with WF_Perm_db perm_db. +Qed. + +Lemma invperm_of_insertion_sort_list_WF n f : + WF_Perm n (invperm_of_swap_list (insertion_sort_list n f)). +Proof. + rewrite <- (length_insertion_sort_list n f) at 1. + auto with WF_Perm_db perm_db. +Qed. + +#[export] Hint Resolve perm_of_insertion_sort_list_WF + invperm_of_swap_list_WF : WF_Perm_db. + + +Lemma perm_of_insertion_sort_list_perm_eq_perm_inv n f : permutation n f -> + perm_eq n (perm_of_swap_list (insertion_sort_list n f)) (perm_inv n f). +Proof. + intros Hperm. + apply (perm_bounded_rinv_injective_of_injective n f); try cleanup_perm_inv. + pose proof (perm_of_swap_list_bounded (insertion_sort_list n f) + (insertion_sort_list_is_swap_list n f)) as H. + rewrite (length_insertion_sort_list n f) in H. + exact H. +Qed. + +#[export] Hint Resolve + perm_of_insertion_sort_list_perm_eq_perm_inv : perm_inv_db. + +#[export] Hint Rewrite + perm_of_insertion_sort_list_perm_eq_perm_inv + using solve [auto with perm_db] : perm_inv_db. + + +Lemma perm_of_insertion_sort_list_eq_perm_inv' n f : permutation n f -> + perm_of_swap_list (insertion_sort_list n f) = + perm_inv' n f. +Proof. + intros Hf. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite + perm_of_insertion_sort_list_eq_perm_inv' + using solve [auto with perm_db] : perm_inv_db. + + +Lemma perm_inv_of_insertion_sort_list_perm_eq n f : permutation n f -> + perm_eq n (perm_inv n (perm_of_swap_list (insertion_sort_list n f))) f. +Proof. + intros Hf. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve perm_inv_of_insertion_sort_list_perm_eq : perm_inv_db. +#[export] Hint Rewrite perm_inv_of_insertion_sort_list_perm_eq + using solve [auto with perm_db] : perm_inv_db. + +Lemma perm_inv'_of_insertion_sort_list_eq n f : + permutation n f -> WF_Perm n f -> + perm_inv' n (perm_of_swap_list (insertion_sort_list n f)) = f. +Proof. + intros. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite perm_inv'_of_insertion_sort_list_eq + using solve [auto with perm_db WF_Perm_db] : perm_inv_db. + +Lemma perm_eq_perm_of_insertion_sort_list_of_perm_inv n f : permutation n f -> + perm_eq n f (perm_of_swap_list (insertion_sort_list n (perm_inv n f))). +Proof. + intros Hf. + cleanup_perm_inv. +Qed. + +Lemma insertion_sort_list_S n f : + insertion_sort_list (S n) f = + (perm_inv (S n) f n) :: + (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)). +Proof. easy. Qed. + +Lemma perm_of_swap_list_cons a l : + perm_of_swap_list (a :: l) = + swap_perm a (length l) (S (length l)) ∘ perm_of_swap_list l. +Proof. easy. Qed. + +Lemma invperm_of_swap_list_cons a l : + invperm_of_swap_list (a :: l) = + invperm_of_swap_list l ∘ swap_perm a (length l) (S (length l)). +Proof. easy. Qed. + +Lemma perm_of_insertion_sort_list_S n f : + perm_of_swap_list (insertion_sort_list (S n) f) = + swap_perm (perm_inv (S n) f n) n (S n) ∘ + perm_of_swap_list (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)). +Proof. + rewrite insertion_sort_list_S, perm_of_swap_list_cons. + rewrite length_insertion_sort_list. + easy. +Qed. + +Lemma invperm_of_insertion_sort_list_S n f : + invperm_of_swap_list (insertion_sort_list (S n) f) = + invperm_of_swap_list (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)) + ∘ swap_perm (perm_inv (S n) f n) n (S n). +Proof. + rewrite insertion_sort_list_S, invperm_of_swap_list_cons. + rewrite length_insertion_sort_list. + easy. +Qed. + +Lemma perm_of_swap_list_permutation l : swap_list_spec l = true -> + permutation (length l) (perm_of_swap_list l). +Proof. + intros Hsw. + induction l; + [ exists idn; easy |]. + simpl in *. + rewrite andb_true_iff, Nat.ltb_lt in Hsw. + destruct Hsw. + pose proof (fun f => permutation_of_le_permutation_WF f + (length l) (S (length l)) ltac:(lia)). + cleanup_perm_inv. +Qed. + +Lemma invperm_of_swap_list_permutation l : swap_list_spec l = true -> + permutation (length l) (invperm_of_swap_list l). +Proof. + intros Hsw. + induction l; + [ exists idn; easy |]. + simpl in *. + rewrite andb_true_iff, Nat.ltb_lt in Hsw. + destruct Hsw. + pose proof (fun f => permutation_of_le_permutation_WF f + (length l) (S (length l)) ltac:(lia)). + cleanup_perm_inv. +Qed. + +Lemma perm_of_insertion_sort_list_permutation n f: + permutation n (perm_of_swap_list (insertion_sort_list n f)). +Proof. + rewrite <- (length_insertion_sort_list n f) at 1. + apply perm_of_swap_list_permutation. + apply insertion_sort_list_is_swap_list. +Qed. + +Lemma invperm_of_insertion_sort_list_permutation n f: + permutation n (invperm_of_swap_list (insertion_sort_list n f)). +Proof. + rewrite <- (length_insertion_sort_list n f) at 1. + apply invperm_of_swap_list_permutation. + apply insertion_sort_list_is_swap_list. +Qed. + +#[export] Hint Resolve + perm_of_swap_list_permutation + invperm_of_swap_list_permutation + perm_of_insertion_sort_list_permutation + invperm_of_insertion_sort_list_permutation : perm_db. + + + + + +Lemma perm_eq_invperm_of_insertion_sort_list n f : permutation n f -> + perm_eq n f (invperm_of_swap_list (insertion_sort_list n f)). +Proof. + intros Hperm. + perm_eq_by_inv_inj (perm_of_swap_list (insertion_sort_list n f)) n. +Qed. + +#[export] Hint Rewrite <- perm_eq_invperm_of_insertion_sort_list + using solve [auto with perm_db] : perm_inv_db. + +Lemma permutation_grow_l' n f : permutation (S n) f -> + perm_eq (S n) f (swap_perm (f n) n (S n) ∘ + perm_of_swap_list (insertion_sort_list n (fswap (perm_inv (S n) f) (f n) n))). +Proof. + intros Hperm. + rewrite (perm_eq_perm_of_insertion_sort_list_of_perm_inv _ _ Hperm) + at 1. + cbn -[perm_inv]. + rewrite length_insertion_sort_list, perm_inv_perm_inv by auto. + easy. +Qed. + +Lemma permutation_grow_r' n f : permutation (S n) f -> + perm_eq (S n) f ( + invperm_of_swap_list (insertion_sort_list n (fswap f (perm_inv (S n) f n) n)) + ∘ swap_perm (perm_inv (S n) f n) n (S n)). +Proof. + intros Hperm. + rewrite (perm_eq_invperm_of_insertion_sort_list _ _ Hperm) at 1. + cbn -[perm_inv]. + rewrite length_insertion_sort_list by auto. + easy. +Qed. + +Lemma permutation_grow_l n f : permutation (S n) f -> + exists g k, k < S n /\ + perm_eq (S n) f (swap_perm k n (S n) ∘ g) /\ permutation n g. +Proof. + intros Hperm. + eexists. + exists (f n). + split; [apply permutation_is_bounded; [easy | lia] | split]. + pose proof (perm_eq_perm_of_insertion_sort_list_of_perm_inv _ _ Hperm) as H. + rewrite perm_of_insertion_sort_list_S in H. + rewrite perm_inv_perm_inv in H by (easy || lia). + exact H. + auto with perm_db. +Qed. + +Lemma permutation_grow_r n f : permutation (S n) f -> + exists g k, k < S n /\ perm_eq (S n) f (g ∘ swap_perm k n (S n)) /\ permutation n g. +Proof. + intros Hperm. + eexists. + exists (perm_inv (S n) f n). + split; [apply permutation_is_bounded; [auto with perm_db | lia] | split]. + pose proof (perm_eq_invperm_of_insertion_sort_list _ _ Hperm) as H. + rewrite invperm_of_insertion_sort_list_S in H. + exact H. + auto with perm_db. +Qed. + + + +(* Section on stack_perms *) +Lemma stack_perms_left {n0 n1} {f g} {k} : + k < n0 -> stack_perms n0 n1 f g k = f k. +Proof. + intros Hk. + unfold stack_perms. + replace_bool_lia (k stack_perms n0 n1 f g k = g (k - n0) + n0. +Proof. + intros Hk. + unfold stack_perms. + replace_bool_lia (k stack_perms n0 n1 f g (k + n0) = g k + n0. +Proof. + intros Hk. + rewrite stack_perms_right; [|lia]. + replace (k + n0 - n0) with k by lia. + easy. +Qed. + +Lemma stack_perms_add_right {n0 n1} {f g} {k} : + k < n1 -> stack_perms n0 n1 f g (n0 + k) = g k + n0. +Proof. + rewrite Nat.add_comm. + exact stack_perms_right_add. +Qed. + +Lemma stack_perms_high {n0 n1} {f g} {k} : + n0 + n1 <= k -> (stack_perms n0 n1 f g) k = k. +Proof. + intros H. + unfold stack_perms. + replace_bool_lia (k if k if (¬ k + replace (WF_Perm n) with (WF_Perm (n0 + n1)) by (f_equal; lia); + apply stack_perms_WF : WF_Perm_db. + +Lemma stack_perms_bounded {n0 n1} {f g} : + perm_bounded n0 f -> perm_bounded n1 g -> + perm_bounded (n0 + n1) (stack_perms n0 n1 f g). +Proof. + intros Hf Hg. + intros k Hk. + specialize (Hf k). + specialize (Hg (k - n0)). + unfold stack_perms. + bdestructΩ'. +Qed. + +#[export] Hint Resolve stack_perms_bounded : perm_bounded_db. +#[export] Hint Extern 10 (perm_bounded ?n (stack_perms ?n0 ?n1 ?f ?g)) => + apply (perm_bounded_change_dims n (n0 + n1) ltac:(lia)); + apply stack_perms_bounded : perm_bounded_db. + +Lemma stack_perms_defn n0 n1 f g : + perm_eq (n0 + n1) (stack_perms n0 n1 f g) + (fun k => if k perm_eq n1 ==> eq as stack_perms_perm_eq_to_eq_proper. +Proof. + intros f f' Hf g g' Hg. + eq_by_WF_perm_eq (n0 + n1). + rewrite 2!stack_perms_defn. + intros k Hk. + bdestructΩ'; f_equal; auto with zarith. +Qed. + +Lemma stack_perms_compose {n0 n1} {f g} {f' g'} + (Hf' : perm_bounded n0 f') (Hg' : perm_bounded n1 g') : + (stack_perms n0 n1 f g ∘ stack_perms n0 n1 f' g' + = stack_perms n0 n1 (f ∘ f') (g ∘ g'))%prg. +Proof. + eq_by_WF_perm_eq (n0 + n1). + intros k Hk. + specialize (Hf' k). + specialize (Hg' (k - n0)). + autounfold with perm_unfold_db. + bdestructΩ'. + now rewrite Nat.add_sub. +Qed. + +#[export] Hint Rewrite @stack_perms_compose + using solve [auto with perm_db perm_bounded_db] : perm_inv_db. + +Lemma stack_perms_assoc {n0 n1 n2} {f g h} : + stack_perms (n0 + n1) n2 (stack_perms n0 n1 f g) h = + stack_perms n0 (n1 + n2) f (stack_perms n1 n2 g h). +Proof. + eq_by_WF_perm_eq (n0 + n1 + n2). + do 3 rewrite_strat innermost stack_perms_defn. + rewrite <- Nat.add_assoc. + rewrite stack_perms_defn. + intros k Hk. + rewrite Nat.add_assoc, Nat.sub_add_distr. + bdestructΩ'. +Qed. + +Lemma stack_perms_idn_of_left_right_idn {n0 n1} {f g} + (Hf : forall k, k < n0 -> f k = k) (Hg : forall k, k < n1 -> g k = k) : + stack_perms n0 n1 f g = idn. +Proof. + solve_modular_permutation_equalities. + - apply Hf; easy. + - rewrite Hg; lia. +Qed. + +#[export] Hint Resolve stack_perms_idn_of_left_right_idn + stack_perms_compose | 10 : perm_inv_db. + + +Lemma stack_perms_idn_compose n0 n1 f g + (Hg : perm_bounded n1 g) : + stack_perms n0 n1 idn (f ∘ g) = + stack_perms n0 n1 idn f ∘ stack_perms n0 n1 idn g. +Proof. + cleanup_perm. +Qed. + +Lemma stack_perms_compose_idn n0 n1 f g + (Hg : perm_bounded n0 g) : + stack_perms n0 n1 (f ∘ g) idn = + stack_perms n0 n1 f idn ∘ stack_perms n0 n1 g idn. +Proof. + cleanup_perm. +Qed. + + +Lemma stack_perms_WF_idn n0 n1 f + (H : WF_Perm n0 f) : + stack_perms n0 n1 f idn = f. +Proof. + solve_modular_permutation_equalities; + rewrite H; lia. +Qed. + +#[export] Hint Rewrite stack_perms_WF_idn + using (solve [auto with WF_Perm_db]) : perm_inv_db. + +Lemma stack_perms_rinv {n0 n1} {f g} {finv ginv} + (Hf: forall k, k < n0 -> (f k < n0 /\ finv k < n0 /\ finv (f k) = k /\ f (finv k) = k)) + (Hg: forall k, k < n1 -> (g k < n1 /\ ginv k < n1 /\ ginv (g k) = k /\ g (ginv k) = k)) : + stack_perms n0 n1 f g ∘ stack_perms n0 n1 finv ginv = idn. +Proof. + unfold compose. + solve_modular_permutation_equalities. + 1-3: specialize (Hf _ H); lia. + - replace (ginv (k - n0) + n0 - n0) with (ginv (k - n0)) by lia. + assert (Hkn0: k - n0 < n1) by lia. + specialize (Hg _ Hkn0). + lia. + - assert (Hkn0: k - n0 < n1) by lia. + specialize (Hg _ Hkn0). + lia. +Qed. + +Lemma is_inv_iff_inv_is n f finv : + (forall k, k < n -> finv k < n /\ f k < n /\ f (finv k) = k /\ finv (f k) = k)%nat + <-> (forall k, k < n -> f k < n /\ finv k < n /\ finv (f k) = k /\ f (finv k) = k)%nat. +Proof. + split; intros H k Hk; specialize (H k Hk); easy. +Qed. + +Lemma stack_perms_linv {n0 n1} {f g} {finv ginv} + (Hf: forall k, k < n0 -> (f k < n0 /\ finv k < n0 /\ finv (f k) = k /\ f (finv k) = k)) + (Hg: forall k, k < n1 -> (g k < n1 /\ ginv k < n1 /\ ginv (g k) = k /\ g (ginv k) = k)) : + stack_perms n0 n1 finv ginv ∘ stack_perms n0 n1 f g = idn. +Proof. + now rewrite stack_perms_rinv + by now apply is_inv_iff_inv_is. +Qed. + +Lemma stack_perms_perm_eq_inv_of_perm_eq_inv {n0 n1} {f g} {finv ginv} + (Hf : perm_eq n0 (f ∘ finv) idn) + (Hg : perm_eq n1 (g ∘ ginv) idn) + (Hfinv : perm_bounded n0 finv) + (Hginv : perm_bounded n1 ginv) : + perm_eq (n0 + n1) + (stack_perms n0 n1 f g ∘ stack_perms n0 n1 finv ginv) + idn. +Proof. + rewrite stack_perms_compose by easy. + now rewrite Hf, Hg, stack_perms_idn_idn. +Qed. + +#[export] Hint Resolve stack_perms_perm_eq_inv_of_perm_eq_inv : perm_inv_db. + +Lemma stack_perms_inv_of_perm_eq_inv {n0 n1} {f g} {finv ginv} + (Hf : perm_eq n0 (f ∘ finv) idn) + (Hg : perm_eq n1 (g ∘ ginv) idn) + (Hfinv : perm_bounded n0 finv) + (Hginv : perm_bounded n1 ginv) : + stack_perms n0 n1 f g ∘ stack_perms n0 n1 finv ginv = idn. +Proof. + eq_by_WF_perm_eq (n0 + n1). + auto with perm_inv_db. +Qed. + +#[export] Hint Resolve stack_perms_inv_of_perm_eq_inv : perm_inv_db. + +#[export] Hint Resolve permutation_is_bounded : perm_bounded_db. + +Lemma stack_perms_permutation {n0 n1 f g} + (Hf : permutation n0 f) (Hg: permutation n1 g) : + permutation (n0 + n1) (stack_perms n0 n1 f g). +Proof. + perm_by_inverse (stack_perms n0 n1 (perm_inv n0 f) (perm_inv n1 g)). +Qed. + +#[export] Hint Resolve stack_perms_permutation : perm_db. +#[export] Hint Extern 10 (permutation ?n (stack_perms ?n0 ?n1 ?f ?g)) => + replace (permutation n) with (permutation (n0 + n1)) by (f_equal; lia); + apply stack_perms_permutation : perm_db. + +Lemma perm_inv_stack_perms n m f g + (Hf : permutation n f) (Hg : permutation m g) : + perm_eq (n + m) + (perm_inv (n + m) (stack_perms n m f g)) + (stack_perms n m (perm_inv n f) (perm_inv m g)). +Proof. + perm_eq_by_inv_inj (stack_perms n m f g) (n+m). +Qed. + +#[export] Hint Rewrite perm_inv_stack_perms + using solve [auto with perm_db] : perm_inv_db. + +Lemma stack_perms_proper {n0 n1} {f f' g g'} + (Hf : perm_eq n0 f f') (Hg : perm_eq n1 g g') : + perm_eq (n0 + n1) + (stack_perms n0 n1 f g) + (stack_perms n0 n1 f' g'). +Proof. + intros k Hk. + unfold stack_perms. + bdestructΩ'; [apply Hf | f_equal; apply Hg]; lia. +Qed. + +#[export] Hint Resolve stack_perms_proper : perm_inv_db. + +Lemma stack_perms_proper_eq {n0 n1} {f f' g g'} + (Hf : perm_eq n0 f f') (Hg : perm_eq n1 g g') : + stack_perms n0 n1 f g = + stack_perms n0 n1 f' g'. +Proof. + eq_by_WF_perm_eq (n0 + n1); cleanup_perm_inv. +Qed. + +#[export] Hint Resolve stack_perms_proper_eq : perm_inv_db. + +Lemma perm_inv'_stack_perms n m f g + (Hf : permutation n f) (Hg : permutation m g) : + perm_inv' (n + m) (stack_perms n m f g) = + stack_perms n m (perm_inv' n f) (perm_inv' m g). +Proof. + permutation_eq_by_WF_inv_inj (stack_perms n m f g) (n+m). +Qed. + +#[export] Hint Rewrite @perm_inv'_stack_perms + using solve [auto with perm_db] : perm_inv_db. + +Lemma stack_perms_diag_split n m f g + (Hg : perm_bounded m g) : + stack_perms n m f g = stack_perms n m f idn ∘ stack_perms n m idn g. +Proof. cleanup_perm. Qed. + +Lemma stack_perms_antidiag_split n m f g (Hf : perm_bounded n f) : + stack_perms n m f g = stack_perms n m idn g ∘ stack_perms n m f idn. +Proof. cleanup_perm. Qed. + + +Lemma contract_perm_perm_eq_of_perm_eq n f g a : + a < n -> perm_eq n f g -> + perm_eq (n - 1) (contract_perm f a) (contract_perm g a). +Proof. + intros Ha Hfg. + intros k Hk. + unfold contract_perm. + now rewrite !Hfg by lia. +Qed. + +#[export] Hint Resolve contract_perm_perm_eq_of_perm_eq : perm_inv_db. + +Add Parametric Morphism n : contract_perm with signature + perm_eq n ==> on_predicate_relation_l (fun k => k < n) eq ==> + perm_eq (n - 1) as compose_perm_perm_eq_proper. +Proof. + intros f g Hfg k l [? ->]. + now apply contract_perm_perm_eq_of_perm_eq. +Qed. + +#[export] Hint Extern 0 (_ < _) => + lia : proper_side_conditions_db. + +#[export] Hint Extern 2 (_ < _) => + solve [auto with perm_bounded_db perm_db] : proper_side_conditions_db. + +#[export] Hint Extern 5 (?f ?k < ?n) => + (* idtac "TRYNG" f k n; *) + let Hk := fresh "Hk" in + assert (Hk : k < n) by (easy + lia); + (* idtac "SECOND STEP"; *) + let Hfbdd := fresh "Hfbdd" in + assert (Hfbdd : perm_bounded f n) by + (auto with perm_bounded_db perm_db zarith); + (* idtac "THIRD STEP"; *) + exact (Hfbdd k Hk) : proper_side_conditions_db. + +Lemma contract_perm_bounded {n f} (Hf : perm_bounded n f) a : + a < n -> + perm_bounded (n - 1) (contract_perm f a). +Proof. + intros Ha k Hk. + pose proof (Hf a Ha). + pose proof (Hf k ltac:(lia)). + pose proof (Hf (k+1) ltac:(lia)). + unfold contract_perm. + bdestructΩ'. +Qed. + +#[export] Hint Resolve contract_perm_bounded : perm_bounded_db. + +Lemma contract_perm_compose a n (Ha : a < n) f g + (Hg : permutation n g) + (Hf : perm_inj n f) : + perm_eq (n - 1) (contract_perm f (g a) ∘ contract_perm g a) + (contract_perm (f ∘ g) a). +Proof. + intros k Hk. + pose proof (permutation_is_injective n g Hg k a ltac:(lia) Ha) as Hgka. + pose proof (permutation_is_injective n g Hg (k + 1) a ltac:(lia) Ha) + as HgSka. + pose proof (Hf a (g k) Ha (permutation_is_bounded n g Hg k ltac:(lia))) + as Hfk. + unfold contract_perm, compose. + bdestructΩ'_with ltac:(rewrite ?Nat.sub_add in * by lia; + try lia). +Qed. + +Lemma contract_perm_compose' a b n (Ha : a < n) f g + (Hg : permutation n g) + (Hf : perm_inj n f) + (Hb : b = g a) : + perm_eq (n - 1) (contract_perm f b ∘ contract_perm g a) + (contract_perm (f ∘ g) a). +Proof. + subst. + now apply contract_perm_compose. +Qed. + +Lemma contract_perm_idn a : contract_perm idn a = idn. +Proof. + unfold contract_perm. + solve_modular_permutation_equalities. +Qed. + +#[export] Hint Rewrite contract_perm_idn : perm_cleanup_db. + +Lemma contract_perm_permutation {n f} (Hf : permutation n f) a : + a < n -> + permutation (n - 1) (contract_perm f a). +Proof. + intros Ha. + pose proof (fun x y => permutation_eq_iff x y Hf) as Hfinj. + perm_by_inverse ((contract_perm (perm_inv n f) (f a))); + change (?f (?g k)) with ((f ∘ g) k); + match goal with |- ?f ?k = ?k => + enough (Heq : perm_eq (n - 1) f idn) by (exact (Heq k Hk)) + end; + rewrite contract_perm_compose'; cleanup_perm. +Qed. + +#[export] Hint Resolve contract_perm_permutation : perm_db. + +Lemma contract_perm_inv n f (Hf : permutation n f) a : + a < n -> + perm_eq (n - 1) + (perm_inv (n - 1) (contract_perm f a)) + (contract_perm (perm_inv n f) (f a)). +Proof. + intros Ha. + perm_eq_by_inv_inj (contract_perm f a) (n - 1). + rewrite contract_perm_compose; cleanup_perm. +Qed. + +#[export] Hint Resolve contract_perm_inv : perm_inv_db. + +#[export] Hint Rewrite contract_perm_inv + using solve [lia + auto with perm_db perm_bounded_db] : perm_inv_db. + +Lemma contract_perm_big n a f (Ha : n <= a) (Hfa : n <= f a) + (Hf : perm_bounded n f) : + perm_eq (n - 1) (contract_perm f a) f. +Proof. + intros k Hk. + unfold contract_perm. + pose proof (Hf k ltac:(lia)). + bdestructΩ'. +Qed. + +Lemma contract_perm_big_WF n a f (Ha : n <= a) (HfWF : WF_Perm n f) + (Hf : perm_bounded n f) : + perm_eq (n - 1) (contract_perm f a) f. +Proof. + apply contract_perm_big; try rewrite HfWF; easy. +Qed. + +Lemma contract_perm_WF n f a : WF_Perm n f -> a < n -> f a < n -> + WF_Perm (n - 1) (contract_perm f a). +Proof. + intros Hf Ha Hfa. + intros k Hk. + unfold contract_perm. + bdestruct (a =? f a); [ + replace <- (f a) in *; + bdestructΩ'; + rewrite ?Hf in * by lia; try lia| + ]. + bdestructΩ'; + rewrite ?Hf in * by lia; lia. +Qed. + +#[export] Hint Extern 0 (WF_Perm _ (contract_perm _ _)) => + apply contract_perm_WF; + [| auto using permutation_is_bounded + with perm_bounded_db..] : WF_Perm_db. + +Lemma contract_perm_inv' {n f} (Hf : permutation n f) a : + WF_Perm n f -> + a < n -> + perm_inv' (n - 1) (contract_perm f a) = + contract_perm (perm_inv' n f) (f a). +Proof. + intros Hfwf Ha. + eq_by_WF_perm_eq (n-1). + cleanup_perm. +Qed. + +#[export] Hint Rewrite @contract_perm_inv' + using (match goal with + | |- WF_Perm _ _ => solve [auto with WF_Perm_db perm_db perm_inv_db] + | |- _ => auto with perm_db + end) : perm_inv_db. + + + +(* Section on rotr / rotl *) +Lemma rotr_defn n m : + perm_eq n (rotr n m) (fun k => (k + m) mod n). +Proof. + intros k Hk. + unfold rotr. + bdestructΩ'. +Qed. + +Lemma rotl_defn n m : + perm_eq n (rotl n m) (fun k => (k + (n - m mod n)) mod n). +Proof. + intros k Hk. + unfold rotl. + bdestructΩ'. +Qed. + +Lemma rotr_WF n m : + WF_Perm n (rotr n m). +Proof. intros k Hk. unfold rotr. bdestruct_one; lia. Qed. + +Lemma rotl_WF n m : + WF_Perm n (rotl n m). +Proof. intros k Hk. unfold rotl. bdestruct_one; lia. Qed. + +#[export] Hint Resolve rotr_WF rotl_WF : WF_Perm_db. + +Lemma rotr_bdd {n m} : + forall k, k < n -> (rotr n m) k < n. +Proof. + intros. unfold rotr. bdestruct_one; [lia|]. + apply Nat.mod_upper_bound; lia. +Qed. + +Lemma rotl_bdd {n m} : + forall k, k < n -> (rotl n m) k < n. +Proof. + intros. unfold rotl. bdestruct_one; [lia|]. + apply Nat.mod_upper_bound; lia. +Qed. + +#[export] Hint Resolve rotr_bdd rotl_bdd : perm_bounded_db. + +Lemma rotr_rotl_inv n m : + ((rotr n m) ∘ (rotl n m) = idn)%prg. +Proof. + apply functional_extensionality; intros k. + unfold compose, rotl, rotr. + bdestruct (n <=? k); [bdestructΩ'|]. + assert (Hn0 : n <> 0) by lia. + bdestruct_one. + - pose proof (Nat.mod_upper_bound (k + (n - m mod n)) n Hn0) as Hbad. + lia. (* contradict Hbad *) + - rewrite Nat.Div0.add_mod_idemp_l. + rewrite <- Nat.add_assoc. + replace (n - m mod n + m) with + (n - m mod n + (n * (m / n) + m mod n)) by + (rewrite <- (Nat.div_mod m n Hn0); easy). + pose proof (Nat.mod_upper_bound m n Hn0). + replace (n - m mod n + (n * (m / n) + m mod n)) with + (n * (1 + m / n)) by lia. + rewrite Nat.mul_comm, Nat.Div0.mod_add. + apply Nat.mod_small, H. +Qed. + +Lemma rotl_rotr_inv n m : + ((rotl n m) ∘ (rotr n m) = idn)%prg. +Proof. + apply functional_extensionality; intros k. + unfold compose, rotl, rotr. + bdestruct (n <=? k); [bdestructΩ'|]. + assert (Hn0 : n <> 0) by lia. + bdestruct_one. + - pose proof (Nat.mod_upper_bound (k + m) n Hn0) as Hbad. + lia. (* contradict Hbad *) + - rewrite Nat.Div0.add_mod_idemp_l. + rewrite <- Nat.add_assoc. + rewrite (Nat.div_mod_eq m n) at 1. + pose proof (Nat.mod_upper_bound m n Hn0). + replace ((n * (m / n) + m mod n) + (n - m mod n)) with + (n * (1 + m / n)) by lia. + rewrite Nat.mul_comm, Nat.Div0.mod_add. + apply Nat.mod_small, H. +Qed. + +#[export] Hint Rewrite rotr_rotl_inv rotl_rotr_inv : perm_inv_db. + +Lemma rotr_perm {n m} : permutation n (rotr n m). +Proof. + perm_by_inverse (rotl n m). +Qed. + +Lemma rotl_perm {n m} : permutation n (rotl n m). +Proof. + perm_by_inverse (rotr n m). +Qed. + +#[export] Hint Resolve rotr_perm rotl_perm : perm_db. + +Lemma rotr_0_r n : rotr n 0 = idn. +Proof. + apply functional_extensionality; intros k. + unfold rotr. + bdestructΩ'. + rewrite Nat.mod_small; lia. +Qed. + +Lemma rotl_0_r n : rotl n 0 = idn. +Proof. + apply functional_extensionality; intros k. + unfold rotl. + bdestructΩ'. + rewrite Nat.Div0.mod_0_l, Nat.sub_0_r. + replace (k + n) with (k + 1 * n) by lia. + rewrite Nat.Div0.mod_add, Nat.mod_small; lia. +Qed. + +Lemma rotr_0_l k : rotr 0 k = idn. +Proof. + apply functional_extensionality; intros a. + unfold rotr. + bdestructΩ'. +Qed. + +Lemma rotl_0_l k : rotl 0 k = idn. +Proof. + apply functional_extensionality; intros a. + unfold rotl. + bdestructΩ'. +Qed. + +#[export] Hint Rewrite rotr_0_r rotl_0_r rotr_0_l rotl_0_l : perm_cleanup_db. + +Lemma rotr_rotr n k l : + ((rotr n k) ∘ (rotr n l) = rotr n (k + l))%prg. +Proof. + apply functional_extensionality; intros a. + unfold compose, rotr. + symmetry. + bdestructΩ'. + - pose proof (Nat.mod_upper_bound (a + l) n); lia. + - rewrite Nat.Div0.add_mod_idemp_l. + f_equal; lia. +Qed. + +Lemma rotl_rotl n k l : + ((rotl n k) ∘ (rotl n l) = rotl n (k + l))%prg. +Proof. + permutation_eq_by_WF_inv_inj (rotr n (k + l)) n. + rewrite Nat.add_comm, <- rotr_rotr, <- compose_assoc, + (compose_assoc _ _ _ _ (rotr n l)). + cleanup_perm. +Qed. + +#[export] Hint Rewrite rotr_rotr rotl_rotl : perm_cleanup_db. + +Lemma rotr_n n : rotr n n = idn. +Proof. + apply functional_extensionality; intros a. + unfold rotr. + bdestructΩ'. + replace (a + n) with (a + 1 * n) by lia. + destruct n; [lia|]. + rewrite Nat.Div0.mod_add. + rewrite Nat.mod_small; easy. +Qed. + +#[export] Hint Rewrite rotr_n : perm_cleanup_db. + +Lemma rotr_eq_rotr_mod n k : rotr n k = rotr n (k mod n). +Proof. + eq_by_WF_perm_eq n. + intros a Ha. + unfold rotr. + simplify_bools_lia_one_kernel. + now rewrite Nat.Div0.add_mod, (Nat.mod_small a n Ha). +Qed. + +Lemma rotl_n n : rotl n n = idn. +Proof. + permutation_eq_by_WF_inv_inj (rotr n n) n. +Qed. + +#[export] Hint Rewrite rotl_n : perm_cleanup_db. + +Lemma rotl_eq_rotl_mod n k : rotl n k = rotl n (k mod n). +Proof. + permutation_eq_by_WF_inv_inj (rotr n k) n. + rewrite rotr_eq_rotr_mod; cleanup_perm_inv. +Qed. + +Lemma rotr_eq_rotl_sub n k : + rotr n k = rotl n (n - k mod n). +Proof. + rewrite rotr_eq_rotr_mod. + permutation_eq_by_WF_inv_inj (rotl n (k mod n)) n. + cleanup_perm. + destruct n; [rewrite rotl_0_l; easy|]. + pose proof (Nat.mod_upper_bound k (S n)). + rewrite Nat.sub_add by lia. + cleanup_perm. +Qed. + +Lemma rotl_eq_rotr_sub n k : + rotl n k = rotr n (n - k mod n). +Proof. + permutation_eq_by_WF_inv_inj (rotr n k) n. + destruct n; [cbn; rewrite 2!rotr_0_l, compose_idn_l; easy|]. + rewrite (rotr_eq_rotr_mod _ k), rotr_rotr. + pose proof (Nat.mod_upper_bound k (S n)). + rewrite Nat.sub_add by lia. + cleanup_perm. +Qed. + +Lemma rotl_eq_rotr_sub_le n k : k <= n -> + rotl n k = rotr n (n - k). +Proof. + intros Hk. + bdestruct (k =? n); + [subst; rewrite Nat.sub_diag; cleanup_perm|]. + now rewrite rotl_eq_rotr_sub, Nat.mod_small by lia. +Qed. + + +Lemma rotr_add_n_l n k : + rotr n (n + k) = rotr n k. +Proof. + rewrite rotr_eq_rotr_mod. + rewrite Nat.add_comm, mod_add_n_r. + now rewrite <- rotr_eq_rotr_mod. +Qed. + +Lemma rotr_add_n_r n k : + rotr n (k + n) = rotr n k. +Proof. + rewrite rotr_eq_rotr_mod. + rewrite mod_add_n_r. + now rewrite <- rotr_eq_rotr_mod. +Qed. + +#[export] Hint Rewrite rotr_add_n_r rotr_add_n_l : perm_cleanup_db. + +Lemma rotr_inv n m : + perm_eq n (perm_inv n (rotr n m)) (rotl n m). +Proof. + perm_eq_by_inv_inj (rotr n m) n. +Qed. + +Lemma rotr_inv' n m : + perm_inv' n (rotr n m) = rotl n m. +Proof. + permutation_eq_by_WF_inv_inj (rotr n m) n. +Qed. + +Lemma rotl_inv n m : + perm_eq n (perm_inv n (rotl n m)) (rotr n m). +Proof. + perm_eq_by_inv_inj (rotl n m) n. +Qed. + +Lemma rotl_inv' n m : + perm_inv' n (rotl n m) = rotr n m. +Proof. + permutation_eq_by_WF_inv_inj (rotl n m) n. +Qed. + +#[export] Hint Resolve rotr_inv rotl_inv : perm_inv_db. +#[export] Hint Rewrite rotr_inv rotr_inv' rotl_inv rotl_inv' : perm_inv_db. + +Lemma rotr_decomp n m : + rotr n m = + fun k => + if n <=? k then k else + if k + m mod n if n + m <=? k then k else + if k if m + n <=? k then k else + if k if k if k + apply (perm_bounded_change_dims n (p * q) + ltac:(show_pow2_le + unify_pows_two; nia)); + apply big_swap_perm_bounded : perm_bounded_db. + +Lemma big_swap_perm_WF p q : + WF_Perm (p + q) (big_swap_perm p q). +Proof. + intros k Hk. + unfold big_swap_perm. + bdestructΩ'. +Qed. + +#[export] Hint Resolve big_swap_perm_WF : WF_Perm_db. +#[export] Hint Extern 1 (WF_Perm ?n (big_swap_perm ?p ?q)) => + replace (WF_Perm n) with (WF_Perm (p + q)) by (f_equal; lia); + apply big_swap_perm_WF : WF_Perm_db. + +Lemma big_swap_perm_invol p q : + big_swap_perm p q ∘ big_swap_perm q p = idn. +Proof. + eq_by_WF_perm_eq (p + q). + intros k Hk. + unfold big_swap_perm, compose. + bdestructΩ'. +Qed. + +#[export] Hint Rewrite big_swap_perm_invol : perm_inv_db perm_cleanup_db. + +Lemma big_swap_perm_permutation p q : + permutation (p + q) (big_swap_perm p q). +Proof. + perm_by_inverse (big_swap_perm q p). +Qed. + +#[export] Hint Resolve big_swap_perm_permutation : perm_db. +#[export] Hint Extern 1 (permutation ?n (big_swap_perm ?p ?q)) => + replace (permutation n) with (permutation (p + q)) by (f_equal; lia); + apply big_swap_perm_permutation : perm_db. + +Lemma big_swap_perm_inv p q : + perm_eq (p + q) (perm_inv (p + q) (big_swap_perm p q)) + (big_swap_perm q p). +Proof. + perm_eq_by_inv_inj (big_swap_perm p q) (p + q). +Qed. + +Lemma big_swap_perm_inv_change_dims p q n : n = p + q -> + perm_eq n (perm_inv n (big_swap_perm p q)) + (big_swap_perm q p). +Proof. + intros ->. + apply big_swap_perm_inv. +Qed. + +#[export] Hint Resolve big_swap_perm_inv : perm_inv_db. + +#[export] Hint Rewrite big_swap_perm_inv : perm_inv_db. +#[export] Hint Rewrite big_swap_perm_inv_change_dims + using lia : perm_inv_db. + +Lemma big_swap_perm_inv' p q : + perm_inv' (p + q) (big_swap_perm p q) = + big_swap_perm q p. +Proof. + eq_by_WF_perm_eq (p + q). + cleanup_perm_inv. +Qed. + +Lemma big_swap_perm_inv'_change_dims p q n : n = p + q -> + perm_inv' n (big_swap_perm p q) = + big_swap_perm q p. +Proof. + intros ->. + apply big_swap_perm_inv'. +Qed. + +#[export] Hint Rewrite big_swap_perm_inv' : perm_inv_db. +#[export] Hint Rewrite big_swap_perm_inv'_change_dims + using lia : perm_inv_db. + +Lemma big_swap_perm_0_l q : + big_swap_perm 0 q = idn. +Proof. + eq_by_WF_perm_eq q. + intros k Hk. + unfold big_swap_perm; bdestructΩ'. +Qed. + +Lemma big_swap_perm_0_r p : + big_swap_perm p 0 = idn. +Proof. + eq_by_WF_perm_eq p. + intros k Hk. + unfold big_swap_perm; bdestructΩ'. +Qed. + +#[export] Hint Rewrite big_swap_perm_0_l big_swap_perm_0_r : perm_cleanup_db. + +Lemma big_swap_perm_ltb_r n m k : + big_swap_perm n m k n - S k). +Proof. + intros k Hk. + unfold reflect_perm. + bdestructΩ'. +Qed. + +Lemma reflect_perm_invol n k : + reflect_perm n (reflect_perm n k) = k. +Proof. + unfold reflect_perm; bdestructΩ'. +Qed. + +Lemma reflect_perm_invol_eq n : + reflect_perm n ∘ reflect_perm n = idn. +Proof. + apply functional_extensionality, reflect_perm_invol. +Qed. + +#[export] Hint Rewrite reflect_perm_invol reflect_perm_invol_eq : perm_inv_db. + +Lemma reflect_perm_bounded n : perm_bounded n (reflect_perm n). +Proof. + intros k Hk. + unfold reflect_perm; bdestructΩ'. +Qed. + +#[export] Hint Resolve reflect_perm_bounded : perm_bounded_db. + +Lemma reflect_perm_permutation n : + permutation n (reflect_perm n). +Proof. + perm_by_inverse (reflect_perm n). +Qed. + +#[export] Hint Resolve reflect_perm_permutation : perm_db. + +Lemma reflect_perm_WF n : WF_Perm n (reflect_perm n). +Proof. + intros k Hk; unfold reflect_perm; bdestructΩ'. +Qed. + +#[export] Hint Resolve reflect_perm_WF : WF_Perm_db. + +Lemma reflect_perm_inv n : + perm_eq n (perm_inv n (reflect_perm n)) (reflect_perm n). +Proof. + perm_eq_by_inv_inj (reflect_perm n) n. +Qed. + +#[export] Hint Resolve reflect_perm_inv : perm_inv_db. +#[export] Hint Rewrite reflect_perm_inv : perm_inv_db. + +Lemma reflect_perm_inv' n : + perm_inv' n (reflect_perm n) = reflect_perm n. +Proof. + eq_by_WF_perm_eq n. + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite reflect_perm_inv : perm_inv_db. + +Lemma swap_perm_conj_reflect_eq a b n + (Ha : a < n) (Hb : b < n) : + reflect_perm n ∘ swap_perm a b n ∘ reflect_perm n = + swap_perm (n - S a) (n - S b) n. +Proof. + eq_by_WF_perm_eq n. + rewrite reflect_perm_defn at 1. + rewrite reflect_perm_defn, 2!swap_perm_defn by lia. + intros k Hk. + unfold compose. + replace_bool_lia (n - S k =? a) (k =? n - S a). + replace_bool_lia (n - S k =? b) (k =? n - S b). + bdestructΩ'. +Qed. + + + +Lemma swap_block_perm_sub padl padm m a k : + m <= k -> + swap_block_perm padl padm a (k - m) = + swap_block_perm (m + padl) padm a k - m. +Proof. + intros Hk. + unfold swap_block_perm. + bdestructΩ'. +Qed. + +Lemma swap_block_perm_invol padl padm a k : + swap_block_perm padl padm a (swap_block_perm padl padm a k) = k. +Proof. + unfold swap_block_perm. + bdestructΩ'. +Qed. + +Lemma swap_block_perm_invol_eq padl padm a : + swap_block_perm padl padm a ∘ swap_block_perm padl padm a = idn. +Proof. + apply functional_extensionality, swap_block_perm_invol. +Qed. + +#[export] Hint Rewrite swap_block_perm_invol + swap_block_perm_invol_eq : perm_inv_db. + +Lemma swap_block_perm_bounded padl padm padr a : + perm_bounded (padl + a + padm + a + padr) (swap_block_perm padl padm a). +Proof. + intros k Hk. + unfold swap_block_perm. + bdestructΩ'. +Qed. + +Lemma swap_block_perm_bounded_alt padl padm padr a : + perm_bounded (padr + a + padm + a + padl) (swap_block_perm padl padm a). +Proof. + replace (padr + a + padm + a + padl) + with (padl + a + padm + a + padr) by lia. + apply swap_block_perm_bounded. +Qed. + +#[export] Hint Resolve swap_block_perm_bounded + swap_block_perm_bounded_alt : perm_bounded_db. + +Lemma swap_block_perm_permutation padl padm padr a : + permutation (padl + a + padm + a + padr) (swap_block_perm padl padm a). +Proof. + perm_by_inverse (swap_block_perm padl padm a). +Qed. + +Lemma swap_block_perm_permutation_alt padl padm padr a : + permutation (padr + a + padm + a + padl) (swap_block_perm padl padm a). +Proof. + perm_by_inverse (swap_block_perm padl padm a). +Qed. + +#[export] Hint Resolve swap_block_perm_permutation + swap_block_perm_permutation_alt : perm_db. + +Lemma swap_block_perm_WF padl padm padr a : + WF_Perm (padl + a + padm + a + padr) (swap_block_perm padl padm a). +Proof. + unfold swap_block_perm. + intros k Hk; bdestructΩ'. +Qed. + +Lemma swap_block_perm_WF_alt padl padm padr a : + WF_Perm (padl + a + padm + a + padr) (swap_block_perm padr padm a). +Proof. + unfold swap_block_perm. + intros k Hk; bdestructΩ'. +Qed. + +#[export] Hint Resolve swap_block_perm_WF + swap_block_perm_WF_alt : WF_Perm_db. + +Lemma swap_block_perm_inv padl padm padr a : + perm_eq (padl + a + padm + a + padr) + (perm_inv (padl + a + padm + a + padr) + (swap_block_perm padl padm a)) + (swap_block_perm padl padm a). +Proof. + perm_eq_by_inv_inj (swap_block_perm padl padm a) + (padl + a + padm + a + padr). +Qed. + +Lemma swap_block_perm_inv_alt padl padm padr a : + perm_eq (padl + a + padm + a + padr) + (perm_inv (padl + a + padm + a + padr) + (swap_block_perm padr padm a)) + (swap_block_perm padr padm a). +Proof. + perm_eq_by_inv_inj (swap_block_perm padr padm a) + (padl + a + padm + a + padr). +Qed. + +#[export] Hint Resolve swap_block_perm_inv + swap_block_perm_inv_alt : perm_inv_db. + +Lemma swap_block_perm_inv' padl padm padr a : + perm_inv' (padl + a + padm + a + padr) + (swap_block_perm padl padm a) = + swap_block_perm padl padm a. +Proof. + eq_by_WF_perm_eq (padl + a + padm + a + padr). + cleanup_perm_inv. +Qed. + +Lemma swap_block_perm_inv'_alt padl padm padr a : + perm_inv' (padl + a + padm + a + padr) + (swap_block_perm padr padm a) = + swap_block_perm padr padm a. +Proof. + eq_by_WF_perm_eq (padl + a + padm + a + padr). + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite swap_block_perm_inv' + swap_block_perm_inv'_alt : perm_inv_db. + +Lemma swap_block_perm_decomp_eq padl padr padm a : + swap_block_perm padl padm a = + stack_perms padl (a + padm + a + padr) idn + (stack_perms (a + padm + a) padr + ((stack_perms (a + padm) a (rotr (a + padm) a) idn) ∘ + rotr (a + padm + a) (a + padm)) idn). +Proof. + rewrite 2!stack_perms_WF_idn by + eauto using monotonic_WF_Perm with WF_Perm_db zarith. + rewrite 2!rotr_decomp. + pose proof (Nat.mod_small (a + padm) (a + padm + a)) as Hsm. + pose proof (Nat.mod_small (a) (a + padm)) as Hsm'. + pose proof (Nat.mod_upper_bound (a + padm) (a + padm + a)) as Hl. + pose proof (Nat.mod_upper_bound (a) (a + padm)) as Hl'. + assert (Hpadm0: padm = 0 -> a mod (a + padm) = 0) by + (intros ->; rewrite Nat.add_0_r, Nat.Div0.mod_same; easy). + rewrite stack_perms_idn_f. + unfold swap_block_perm. + apply functional_extensionality; intros k. + unfold compose. + bdestruct (a =? 0); + [subst; + rewrite ?Nat.add_0_r, ?Nat.add_0_l, ?Nat.Div0.mod_same in *; + bdestructΩ'|]. + rewrite Hsm in * by lia. + bdestruct (padm =? 0); + [subst; + rewrite ?Nat.add_0_r, ?Nat.add_0_l, ?Nat.Div0.mod_same in *; + bdestructΩ'|]. + rewrite Hsm' in * by lia. + bdestructΩ'. +Qed. + + + + +Lemma qubit_perm_to_nat_perm_defn n p : + perm_eq (2 ^ n) (qubit_perm_to_nat_perm n p) + (fun k => funbool_to_nat n ((nat_to_funbool n k) ∘ p)%prg). +Proof. + unfold qubit_perm_to_nat_perm. + intros k Hk. + simplify_bools_lia_one_kernel. + easy. +Qed. + +Lemma qubit_perm_to_nat_perm_WF n f : + WF_Perm (2^n) (qubit_perm_to_nat_perm n f). +Proof. + intros k Hk. + unfold qubit_perm_to_nat_perm. + bdestructΩ'. +Qed. + +#[export] Hint Resolve qubit_perm_to_nat_perm_WF : WF_Perm_db. +#[export] Hint Extern 100 (WF_Perm ?npow2 (qubit_perm_to_nat_perm ?n _)) => + replace (WF_Perm npow2) with (WF_Perm (2^n)) + by (show_pow2_le + unify_pows_two; nia) : WF_Perm_db. + +Add Parametric Morphism n : (qubit_perm_to_nat_perm n) with signature + perm_eq n ==> eq + as qubit_perm_to_nat_perm_perm_eq_to_eq_proper. +Proof. + intros f g Hfg. + eq_by_WF_perm_eq (2^n). + intros k Hk. + rewrite 2!qubit_perm_to_nat_perm_defn by easy. + apply funbool_to_nat_eq. + intros l Hl. + unfold compose. + now rewrite Hfg. +Qed. + +Lemma qubit_perm_to_nat_perm_bounded n f : + perm_bounded (2 ^ n) (qubit_perm_to_nat_perm n f). +Proof. + intros k Hk. + rewrite qubit_perm_to_nat_perm_defn by easy. + apply funbool_to_nat_bound. +Qed. + +#[export] Hint Resolve qubit_perm_to_nat_perm_bounded : perm_bounded_db. + +Lemma qubit_perm_to_nat_perm_compose n f g : + perm_bounded n f -> + (qubit_perm_to_nat_perm n f ∘ qubit_perm_to_nat_perm n g = + qubit_perm_to_nat_perm n (g ∘ f))%prg. +Proof. + intros Hf. + eq_by_WF_perm_eq (2^n). + rewrite 3!qubit_perm_to_nat_perm_defn. + unfold compose. + intros k Hk. + apply funbool_to_nat_eq. + intros y Hy. + now rewrite funbool_to_nat_inverse by auto. +Qed. + +#[export] Hint Rewrite qubit_perm_to_nat_perm_compose + using solve [auto with perm_bounded_db perm_db] : perm_inv_db. + +Lemma qubit_perm_to_nat_perm_compose_alt n f g (Hf : perm_bounded n f) k : + qubit_perm_to_nat_perm n f (qubit_perm_to_nat_perm n g k) = + qubit_perm_to_nat_perm n (g ∘ f)%prg k. +Proof. + now rewrite <- qubit_perm_to_nat_perm_compose. +Qed. + +#[export] Hint Rewrite qubit_perm_to_nat_perm_compose_alt + using solve [auto with perm_bounded_db perm_db] : perm_inv_db. + +Lemma qubit_perm_to_nat_perm_perm_eq_idn n : + perm_eq (2^n) (qubit_perm_to_nat_perm n idn) idn. +Proof. + rewrite qubit_perm_to_nat_perm_defn. + intros k Hk. + rewrite compose_idn_r. + now apply nat_to_funbool_inverse. +Qed. + +#[export] Hint Resolve qubit_perm_to_nat_perm_perm_eq_idn : perm_inv_db. + +Lemma qubit_perm_to_nat_perm_idn n : + qubit_perm_to_nat_perm n idn = idn. +Proof. + eq_by_WF_perm_eq (2^n). + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite qubit_perm_to_nat_perm_idn : perm_cleanup_db. + +Lemma qubit_perm_to_nat_perm_permutation : forall n p, + permutation n p -> permutation (2^n) (qubit_perm_to_nat_perm n p). +Proof. + intros n p Hp. + perm_by_inverse (qubit_perm_to_nat_perm n (perm_inv n p)). +Qed. + +#[export] Hint Resolve qubit_perm_to_nat_perm_permutation : perm_db. + +Lemma qubit_perm_to_nat_perm_inv n f (Hf : permutation n f) : + perm_eq (2^n) + (perm_inv (2^n) (qubit_perm_to_nat_perm n f)) + (qubit_perm_to_nat_perm n (perm_inv n f)). +Proof. + perm_eq_by_inv_inj (qubit_perm_to_nat_perm n f) (2^n). +Qed. + +#[export] Hint Resolve qubit_perm_to_nat_perm_inv : perm_inv_db. +#[export] Hint Rewrite qubit_perm_to_nat_perm_inv + using solve [auto with perm_db] : perm_inv_db. + +Lemma qubit_perm_to_nat_perm_inv' n f (Hf : permutation n f) : + perm_inv' (2^n) (qubit_perm_to_nat_perm n f) = + qubit_perm_to_nat_perm n (perm_inv' n f). +Proof. + eq_by_WF_perm_eq (2^n). + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite qubit_perm_to_nat_perm_inv' + using solve [auto with perm_db] : perm_inv_db. + +Lemma qubit_perm_to_nat_perm_inj n f g + (Hf : perm_bounded n f) : + perm_eq (2^n) (qubit_perm_to_nat_perm n f) (qubit_perm_to_nat_perm n g) -> + perm_eq n f g. +Proof. + rewrite 2!qubit_perm_to_nat_perm_defn. + intros H i Hi. + specialize (H (2^(n - S (f i))) ltac:(apply Nat.pow_lt_mono_r; + auto with perm_bounded_db zarith)). + unfold qubit_perm_to_nat_perm in H. + rewrite <- funbool_to_nat_eq_iff in H. + specialize (H i Hi). + revert H. + unfold compose. + rewrite Bits.nat_to_funbool_eq. + pose proof (Hf i Hi). + simplify_bools_lia_one_kernel. + rewrite 2!Nat.pow2_bits_eqb. + rewrite Nat.eqb_refl. + bdestructΩ'. +Qed. + + + +Lemma tensor_perms_defn n0 n1 f g : + perm_eq (n0 * n1) (tensor_perms n0 n1 f g) + (fun k => f (k / n1) * n1 + g (k mod n1)). +Proof. + intros k Hk. + unfold tensor_perms. + now simplify_bools_lia_one_kernel. +Qed. + +Lemma tensor_perms_defn_alt n0 n1 f g : + perm_eq (n1 * n0) (tensor_perms n0 n1 f g) + (fun k => f (k / n1) * n1 + g (k mod n1)). +Proof. + now rewrite Nat.mul_comm, tensor_perms_defn. +Qed. + +Lemma tensor_perms_bounded n0 n1 f g : + perm_bounded n0 f -> perm_bounded n1 g -> + perm_bounded (n0 * n1) (tensor_perms n0 n1 f g). +Proof. + intros Hf Hg k Hk. + rewrite tensor_perms_defn by easy. + pose proof (Hf (k / n1) ltac:(show_moddy_lt)). + pose proof (Hg (k mod n1) ltac:(show_moddy_lt)). + show_moddy_lt. +Qed. + +#[export] Hint Resolve tensor_perms_bounded : perm_bounded_db. +#[export] Hint Extern 10 (perm_bounded ?n01 (tensor_perms ?n0 ?n1 ?f ?g)) => + apply (perm_bounded_change_dims n01 (n0 * n1) + ltac:(show_pow2_le + unify_pows_two; nia)); + apply tensor_perms_bounded : perm_bounded_db. + +Lemma tensor_perms_WF n0 n1 f g : + WF_Perm (n0 * n1) (tensor_perms n0 n1 f g). +Proof. + intros k Hk. + unfold tensor_perms. + bdestructΩ'. +Qed. + +#[export] Hint Resolve tensor_perms_WF : WF_Perm_db. +#[export] Hint Extern 100 (WF_Perm ?n01 (tensor_perms ?n0 ?n1 ?f ?g)) => + replace (WF_Perm n01) with (WF_Perm (n0 * n1)) by + (f_equal; nia + show_pow2_le) : WF_Perm_db. + +Lemma tensor_perms_compose n0 n1 f0 f1 g0 g1 : + perm_bounded n0 f1 -> perm_bounded n1 g1 -> + tensor_perms n0 n1 f0 g0 ∘ tensor_perms n0 n1 f1 g1 = + tensor_perms n0 n1 (f0 ∘ f1) (g0 ∘ g1). +Proof. + intros Hf1 Hg1. + eq_by_WF_perm_eq (n0*n1). + rewrite 3!tensor_perms_defn. + intros k Hk. + unfold compose. + rewrite Nat.div_add_l by lia. + pose proof (Hf1 (k / n1) ltac:(show_moddy_lt)). + pose proof (Hg1 (k mod n1) ltac:(show_moddy_lt)). + rewrite (Nat.div_small (g1 _)), mod_add_l, Nat.mod_small by easy. + now rewrite Nat.add_0_r. +Qed. + +#[export] Hint Rewrite tensor_perms_compose : perm_cleanup_db perm_inv_db. + +Lemma tensor_perms_0_l n1 f g : + tensor_perms 0 n1 f g = idn. +Proof. + eq_by_WF_perm_eq (0 * n1). + easy. +Qed. + +Lemma tensor_perms_0_r n0 f g : + tensor_perms n0 0 f g = idn. +Proof. + eq_by_WF_perm_eq (n0 * 0). + intros k Hk; lia. +Qed. + +#[export] Hint Rewrite tensor_perms_0_l + tensor_perms_0_r : perm_cleanup_db perm_inv_db. + +Lemma tensor_perms_perm_eq_proper n0 n1 f f' g g' : + perm_eq n0 f f' -> perm_eq n1 g g' -> + tensor_perms n0 n1 f g = tensor_perms n0 n1 f' g'. +Proof. + intros Hf' Hg'. + eq_by_WF_perm_eq (n0 * n1). + rewrite 2!tensor_perms_defn. + intros k Hk. + now rewrite Hf', Hg' by show_moddy_lt. +Qed. + +#[export] Hint Resolve tensor_perms_perm_eq_proper : perm_inv_db. + +Add Parametric Morphism n0 n1 : (tensor_perms n0 n1) with signature + perm_eq n0 ==> perm_eq n1 ==> eq + as tensor_perms_perm_eq_to_eq_proper. +Proof. + intros; now apply tensor_perms_perm_eq_proper. +Qed. + +Lemma tensor_perms_idn_idn n0 n1 : + tensor_perms n0 n1 idn idn = idn. +Proof. + eq_by_WF_perm_eq (n0 * n1). + rewrite tensor_perms_defn. + intros k Hk. + pose proof (Nat.div_mod_eq k n1). + lia. +Qed. + +#[export] Hint Rewrite tensor_perms_idn_idn : perm_cleanup_db. + +Lemma tensor_perms_idn_idn' n0 n1 f g : + perm_eq n0 f idn -> perm_eq n1 g idn -> + perm_eq (n0 * n1) (tensor_perms n0 n1 f g) idn. +Proof. + intros -> ->. + cleanup_perm. +Qed. + +#[export] Hint Resolve tensor_perms_idn_idn' : perm_inv_db. + +Lemma tensor_perms_permutation n0 n1 f g + (Hf : permutation n0 f) (Hg : permutation n1 g) : + permutation (n0 * n1) (tensor_perms n0 n1 f g). +Proof. + perm_by_inverse (tensor_perms n0 n1 (perm_inv n0 f) (perm_inv n1 g)). +Qed. + +#[export] Hint Resolve tensor_perms_permutation : perm_db. + +Lemma tensor_perms_n_2_permutation n f g + (Hf : permutation n f) (Hg : permutation 2 g) : + permutation (n + n) (tensor_perms n 2 f g). +Proof. + replace (n + n) with (n * 2) by lia. + cleanup_perm. +Qed. + +#[export] Hint Resolve tensor_perms_n_2_permutation : perm_db. + +Lemma tensor_perms_inv n0 n1 f g + (Hf : permutation n0 f) (Hg : permutation n1 g) : + perm_eq (n0 * n1) + (perm_inv (n0 * n1) (tensor_perms n0 n1 f g)) + (tensor_perms n0 n1 (perm_inv n0 f) (perm_inv n1 g)). +Proof. + perm_eq_by_inv_inj (tensor_perms n0 n1 f g) (n0*n1). +Qed. + +#[export] Hint Resolve tensor_perms_inv : perm_inv_db. +#[export] Hint Rewrite tensor_perms_inv + using solve [auto with perm_db] : perm_inv_db. + +Lemma tensor_perms_inv' n0 n1 f g + (Hf : permutation n0 f) (Hg : permutation n1 g) : + perm_inv' (n0 * n1) (tensor_perms n0 n1 f g) = + tensor_perms n0 n1 (perm_inv' n0 f) (perm_inv' n1 g). +Proof. + permutation_eq_by_WF_inv_inj (tensor_perms n0 n1 f g) (n0*n1). +Qed. + +#[export] Hint Rewrite tensor_perms_inv' + using solve [auto with perm_db] : perm_inv_db. + +Lemma tensor_rotr_idn_eq_rotr_mul n m p : + tensor_perms n p (rotr n m) idn = + rotr (n * p) (m * p). +Proof. + eq_by_WF_perm_eq (n * p). + rewrite 2!rotr_defn, tensor_perms_defn. + intros k Hk. + symmetry. + rewrite (Nat.mul_comm n p). + rewrite Nat.Div0.mod_mul_r. + rewrite Nat.Div0.mod_add. + rewrite Nat.div_add by lia. + lia. +Qed. + + + +Lemma qubit_perm_to_nat_perm_stack_perms n0 n1 f g + (Hf : perm_bounded n0 f) (Hg : perm_bounded n1 g) : + qubit_perm_to_nat_perm (n0 + n1) (stack_perms n0 n1 f g) = + tensor_perms (2^n0) (2^n1) + (qubit_perm_to_nat_perm n0 f) + (qubit_perm_to_nat_perm n1 g). +Proof. + eq_by_WF_perm_eq (2^(n0+n1)). + rewrite stack_perms_defn. + rewrite !qubit_perm_to_nat_perm_defn, Nat.pow_add_r. + rewrite tensor_perms_defn. + intros k Hk. + rewrite funbool_to_nat_add_pow2_join. + apply funbool_to_nat_eq. + intros a Ha. + unfold compose. + bdestruct (a swap_2_perm k < 2. +Proof. + intros Hk. + repeat first [easy | destruct k | cbn; lia]. +Qed. + +#[export] Hint Resolve swap_2_perm_bounded : perm_bounded_db. + +Lemma swap_2_WF k : 1 < k -> swap_2_perm k = k. +Proof. + intros. + repeat first [easy | lia | destruct k]. +Qed. + +Lemma swap_2_WF_Perm : WF_Perm 2 swap_2_perm. +Proof. + intros k. + repeat first [easy | lia | destruct k]. +Qed. + +Global Hint Resolve swap_2_WF_Perm : WF_Perm_db. + +Lemma swap_2_perm_permutation : permutation 2 swap_2_perm. +Proof. + perm_by_inverse swap_2_perm. +Qed. + +Global Hint Resolve swap_2_perm_permutation : perm_db. + +Lemma swap_2_perm_inv : + perm_eq 2 + (perm_inv 2 swap_2_perm) swap_2_perm. +Proof. + perm_eq_by_inv_inj swap_2_perm 2. +Qed. + +Lemma swap_2_perm_inv' : + perm_inv' 2 swap_2_perm = swap_2_perm. +Proof. + permutation_eq_by_WF_inv_inj swap_2_perm 2. +Qed. + +#[export] Hint Resolve swap_2_perm_inv : perm_inv_db. +#[export] Hint Rewrite swap_2_perm_inv' : perm_inv_db. + + + + +Lemma kron_comm_perm_defn p q : + perm_eq (p * q) (kron_comm_perm p q) + (fun k => k mod p * q + k / p). +Proof. + intros k Hk. + unfold kron_comm_perm. + bdestructΩ'. +Qed. + +Lemma kron_comm_perm_defn_alt p q : + perm_eq (q * p) (kron_comm_perm p q) + (fun k => k mod p * q + k / p). +Proof. + intros k Hk. + unfold kron_comm_perm. + bdestructΩ'. +Qed. + +Lemma kron_comm_perm_WF p q : + WF_Perm (p * q) (kron_comm_perm p q). +Proof. + intros k Hk; unfold kron_comm_perm; bdestructΩ'. +Qed. + +Lemma kron_comm_perm_WF_alt p q : + WF_Perm (q * p) (kron_comm_perm p q). +Proof. + rewrite Nat.mul_comm; apply kron_comm_perm_WF. +Qed. + +#[export] Hint Resolve kron_comm_perm_WF kron_comm_perm_WF_alt : WF_Perm_db. +#[export] Hint Extern 10 (WF_Perm ?n (kron_comm_perm ?p ?q)) => + replace (WF_Perm n) with (WF_Perm (p * q)) by + (f_equal; show_pow2_le + unify_pows_two; nia); + apply kron_comm_perm_WF : WF_Perm_db. + +Lemma kron_comm_perm_bounded p q : + perm_bounded (p * q) (kron_comm_perm p q). +Proof. + intros k Hk. + unfold kron_comm_perm. + bdestructΩ'. + show_moddy_lt. +Qed. + +Lemma kron_comm_perm_bounded_alt p q : + perm_bounded (q * p) (kron_comm_perm p q). +Proof. + rewrite Nat.mul_comm. + apply kron_comm_perm_bounded. +Qed. + +#[export] Hint Resolve kron_comm_perm_bounded + kron_comm_perm_bounded_alt : perm_bounded_db. +#[export] Hint Extern 10 (perm_bounded ?n (kron_comm_perm ?p ?q)) => + apply (perm_bounded_change_dims n (p * q) + ltac:(show_pow2_le + unify_pows_two; nia)); + apply kron_comm_perm_bounded : perm_bounded_db. + +Lemma kron_comm_perm_pseudo_invol_perm_eq p q : + perm_eq (p * q) (kron_comm_perm p q ∘ kron_comm_perm q p)%prg idn. +Proof. + intros k Hk. + unfold compose, kron_comm_perm. + simplify_bools_lia_one_kernel. + simplify_bools_moddy_lia_one_kernel. + rewrite (Nat.add_comm _ (k/q)). + rewrite Nat.Div0.mod_add, Nat.div_add by show_nonzero. + rewrite Nat.Div0.div_div, Nat.mod_small by show_moddy_lt. + rewrite (Nat.div_small k (q*p)) by lia. + symmetry. + rewrite (Nat.div_mod_eq k q) at 1; lia. +Qed. + +#[export] Hint Resolve kron_comm_perm_pseudo_invol_perm_eq : perm_inv_db. + +Lemma kron_comm_perm_pseudo_invol_alt_perm_eq p q : + perm_eq (q * p) (kron_comm_perm p q ∘ kron_comm_perm q p)%prg idn. +Proof. + rewrite Nat.mul_comm; cleanup_perm_inv. +Qed. + +#[export] Hint Resolve kron_comm_perm_pseudo_invol_alt_perm_eq : perm_inv_db. + +Lemma kron_comm_perm_pseudo_invol p q : + kron_comm_perm p q ∘ kron_comm_perm q p = idn. +Proof. + eq_by_WF_perm_eq (p*q); cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite kron_comm_perm_pseudo_invol : perm_inv_db. + +Lemma kron_comm_perm_permutation p q : + permutation (p * q) (kron_comm_perm p q). +Proof. + perm_by_inverse (kron_comm_perm q p). +Qed. + +Lemma kron_comm_perm_permutation_alt p q : + permutation (q * p) (kron_comm_perm p q). +Proof. + perm_by_inverse (kron_comm_perm q p). +Qed. + +#[export] Hint Resolve kron_comm_perm_permutation + kron_comm_perm_permutation_alt : perm_db. +#[export] Hint Extern 10 (permutation ?n (kron_comm_perm ?p ?q)) => + replace (permutation n) with (permutation (p * q)) by + (f_equal; show_pow2_le + unify_pows_two; nia); + apply kron_comm_perm_WF : perm_db. + +Lemma kron_comm_perm_inv p q : + perm_eq (p * q) + (perm_inv (p * q) (kron_comm_perm p q)) + (kron_comm_perm q p). +Proof. + perm_eq_by_inv_inj (kron_comm_perm p q) (p * q). +Qed. + +Lemma kron_comm_perm_inv_alt p q : + perm_eq (q * p) + (perm_inv (p * q) (kron_comm_perm p q)) + (kron_comm_perm q p). +Proof. + perm_eq_by_inv_inj (kron_comm_perm p q) (q * p). + rewrite Nat.mul_comm. + cleanup_perm_inv. +Qed. + +Lemma kron_comm_perm_swap_inv p q : + perm_eq (p * q) + (perm_inv (p * q) (kron_comm_perm q p)) + (kron_comm_perm p q). +Proof. + perm_eq_by_inv_inj (kron_comm_perm q p) (p * q). +Qed. + +Lemma kron_comm_perm_swap_inv_alt p q : + perm_eq (q * p) + (perm_inv (p * q) (kron_comm_perm q p)) + (kron_comm_perm p q). +Proof. + perm_eq_by_inv_inj (kron_comm_perm q p) (q * p). + rewrite Nat.mul_comm. + cleanup_perm_inv. +Qed. + +#[export] Hint Resolve kron_comm_perm_inv + kron_comm_perm_inv_alt + kron_comm_perm_swap_inv + kron_comm_perm_swap_inv_alt : perm_inv_db. +#[export] Hint Rewrite kron_comm_perm_inv + kron_comm_perm_inv_alt + kron_comm_perm_swap_inv + kron_comm_perm_swap_inv_alt : perm_inv_db. + +Lemma kron_comm_perm_inv' p q : + perm_inv' (p * q) (kron_comm_perm p q) = + kron_comm_perm q p. +Proof. + eq_by_WF_perm_eq (p * q). + cleanup_perm_inv. +Qed. + +Lemma kron_comm_perm_inv'_alt p q : + perm_inv' (q * p) (kron_comm_perm p q) = + kron_comm_perm q p. +Proof. + eq_by_WF_perm_eq (q * p). + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite kron_comm_perm_inv' + kron_comm_perm_inv'_alt : perm_inv_db. + + + + + + +Lemma stack_perms_big_swap_natural n0 n1 f g + (Hf : perm_bounded n0 f) (Hg : perm_bounded n1 g) : + stack_perms n0 n1 f g ∘ big_swap_perm n1 n0 = + big_swap_perm n1 n0 ∘ stack_perms n1 n0 g f. +Proof. + eq_by_WF_perm_eq (n0 + n1). + rewrite stack_perms_defn. + rewrite Nat.add_comm. + rewrite stack_perms_defn. + intros k Hk. + unfold compose, big_swap_perm. + pose proof (Hf (k - n1)). + pose proof (Hg k). + bdestructΩ'. + now rewrite Nat.add_sub. +Qed. + +Lemma stack_perms_rotr_natural n0 n1 f g + (Hf : perm_bounded n0 f) (Hg : perm_bounded n1 g) : + stack_perms n0 n1 f g ∘ rotr (n0 + n1) n0 = + rotr (n0 + n1) n0 ∘ stack_perms n1 n0 g f. +Proof. + rewrite rotr_add_l. + now apply stack_perms_big_swap_natural. +Qed. + +Lemma stack_perms_rotl_natural n0 n1 f g + (Hf : perm_bounded n0 f) (Hg : perm_bounded n1 g) : + stack_perms n0 n1 f g ∘ rotl (n0 + n1) n1 = + rotl (n0 + n1) n1 ∘ stack_perms n1 n0 g f. +Proof. + rewrite rotl_add_r. + now apply stack_perms_big_swap_natural. +Qed. + +Lemma tensor_perms_kron_comm_perm_natural n0 n1 f g + (Hf : perm_bounded n0 f) (Hg : perm_bounded n1 g) : + tensor_perms n0 n1 f g ∘ kron_comm_perm n0 n1 = + kron_comm_perm n0 n1 ∘ tensor_perms n1 n0 g f. +Proof. + eq_by_WF_perm_eq (n0 * n1). + rewrite tensor_perms_defn, kron_comm_perm_defn, tensor_perms_defn_alt. + intros k Hk. + unfold compose. + rewrite !Nat.div_add_l, !mod_add_l by lia. + pose proof (Hf (k mod n0) ltac:(show_moddy_lt)). + pose proof (Hg (k / n0) ltac:(show_moddy_lt)). + rewrite Nat.Div0.div_div, Nat.div_small, Nat.add_0_r by lia. + rewrite (Nat.mod_small (k / n0)) by (show_moddy_lt). + rewrite (Nat.mod_small (f _)), (Nat.div_small (f _)) by lia. + lia. +Qed. + + +Lemma inv_perm_eq_id_mid {padl padm padr f} + (Hf : permutation (padl + padm + padr) f) + (Hfidn : perm_eq_id_mid padl padm f) : + forall k, k < padl + padm + padr -> + padl <= f k < padl + padm -> f k = k. +Proof. + intros k Hk []. + apply (permutation_is_injective _ _ Hf); [lia..|]. + replace (f k) with (padl + (f k - padl)) by lia. + (* unfold perm_eq_id_mid in Hfidn. *) + apply Hfidn; lia. +Qed. + +Arguments compose_assoc [_ _ _ _]. + +Lemma expand_perm_id_mid_compose (f g : nat -> nat) (padl padm padr : nat) + (Hf : perm_bounded (padl + padr) f) + (Hg : perm_bounded (padl + padr) g) : + expand_perm_id_mid padl padm padr f ∘ expand_perm_id_mid padl padm padr g = + expand_perm_id_mid padl padm padr (f ∘ g). +Proof. + unfold expand_perm_id_mid. + (* cleanup_perm. *) + rewrite (compose_assoc _ (stack_perms _ _ idn (rotr _ padr))), + <- !(compose_assoc _ _ (stack_perms _ _ idn (rotr _ padr))). + cleanup_perm_inv. + cleanup_perm. + rewrite (Nat.add_comm padr padm). + cleanup_perm. + rewrite compose_assoc, <- (compose_assoc _ _ (stack_perms _ _ f _)). + cleanup_perm. +Qed. + +Lemma expand_perm_id_mid_eq_of_perm_eq {padl padr f g} + (Hfg : perm_eq (padl + padr) f g) padm : + expand_perm_id_mid padl padm padr f = expand_perm_id_mid padl padm padr g. +Proof. + unfold expand_perm_id_mid. + do 2 f_equal. + now apply stack_perms_proper_eq. +Qed. + +Lemma expand_perm_id_mid_permutation {padl padr f} + (Hf : permutation (padl + padr) f) padm : + permutation (padl + padm + padr) (expand_perm_id_mid padl padm padr f). +Proof. + unfold expand_perm_id_mid. + rewrite <- Nat.add_assoc. + apply permutation_compose; [|auto with perm_db]. + apply permutation_compose; [auto with perm_db|]. + replace (padl + (padm + padr)) with (padl + padr + padm) by lia. + auto with perm_db. +Qed. + +#[export] Hint Resolve expand_perm_id_mid_permutation : perm_db. + + + +Lemma contract_expand_perm_perm_eq_inv padl padm padr f + (Hf : perm_bounded (padl + padr) f) : + perm_eq (padl + padr) + (contract_perm_id_mid padl padm padr + (expand_perm_id_mid padl padm padr f)) + f. +Proof. + unfold contract_perm_id_mid, expand_perm_id_mid. + rewrite !compose_assoc. + cleanup_perm. + rewrite (Nat.add_comm padr padm). + rewrite <- !compose_assoc. + cleanup_perm. + rewrite (Nat.add_comm padr padm). + cleanup_perm. + intros k Hk. + now rewrite stack_perms_left by easy. +Qed. + + +Lemma contract_perm_id_mid_compose {padl padm padr f} + (Hf : perm_bounded (padl + padm + padr) f) g : + contract_perm_id_mid padl padm padr g ∘ contract_perm_id_mid padl padm padr f = + contract_perm_id_mid padl padm padr (g ∘ f). +Proof. + unfold contract_perm_id_mid. + rewrite (compose_assoc _ (stack_perms _ _ idn (rotr _ padm))), + <- !(compose_assoc _ _ (stack_perms _ _ idn (rotr _ padm))). + cleanup_perm. +Qed. + +Lemma contract_perm_id_mid_permutation_big {padl padm padr f} + (Hf : permutation (padl + padm + padr) f) : + permutation (padl + padm + padr) (contract_perm_id_mid padl padm padr f). +Proof. + unfold contract_perm_id_mid. + rewrite <- Nat.add_assoc in *. + auto with perm_db. +Qed. + +Lemma contract_perm_id_mid_permutation {padl padm padr f} + (Hf : permutation (padl + padm + padr) f) + (Hfid : perm_eq_id_mid padl padm f) : + permutation (padl + padr) (contract_perm_id_mid padl padm padr f). +Proof. + apply (permutation_of_le_permutation_idn_above _ _ _ + (contract_perm_id_mid_permutation_big Hf)); + [lia|]. + intros k []. + unfold contract_perm_id_mid. + unfold compose at 1. + rewrite stack_perms_right by lia. + rewrite rotr_add_l_eq. + do 2 simplify_bools_lia_one_kernel. + unfold compose. + rewrite (Nat.add_comm _ padl), Hfid by lia. + rewrite stack_perms_right by lia. + rewrite rotr_add_r_eq. + bdestructΩ'. +Qed. + +#[export] Hint Resolve contract_perm_id_mid_permutation_big + contract_perm_id_mid_permutation : perm_db. + + +Lemma expand_contract_perm_perm_eq_idn_inv {padl padm padr f} + (Hf : permutation (padl + padm + padr) f) + (Hfidn : perm_eq_id_mid padl padm f) : + perm_eq (padl + padm + padr) + ((expand_perm_id_mid padl padm padr + (contract_perm_id_mid padl padm padr f))) + f. +Proof. + unfold contract_perm_id_mid, expand_perm_id_mid. + intros k Hk. + rewrite (stack_perms_idn_f _ _ (rotr _ padr)) at 2. + unfold compose at 1. + simplify_bools_lia_one_kernel. + replace (if ¬ k nat) (i : nat) : nat := + match n with + | 0 => 0 + | S n' => if big_sum g n' <=? i then n' else + Nsum_index n' g i + end. + +Definition Nsum_offset (n : nat) (g : nat -> nat) (i : nat) : nat := + i - big_sum g (Nsum_index n g i). + +Add Parametric Morphism n : (Nsum_index n) with signature + perm_eq n ==> eq as Nsum_index_perm_eq_to_eq. +Proof. + intros g g' Hg. + apply functional_extensionality; intros k. + induction n; [easy|]. + - cbn -[big_sum]. + assert (Hg' : perm_eq n g g') by (hnf in *; auto). + rewrite IHn by auto. + now rewrite (big_sum_eq_bounded _ _ _ Hg'). +Qed. + +Lemma Nsum_index_total_bounded n g i : + Nsum_index n g i <= n. +Proof. + induction n; [cbn; lia|]. + simpl. + bdestructΩ'. +Qed. + +Lemma Nsum_index_bounded n g i : n <> 0 -> + Nsum_index n g i < n. +Proof. + induction n; [cbn; lia|]. + simpl. + destruct n; bdestructΩ'. +Qed. + +Lemma Nsum_index_spec n g i (Hi : i < big_sum g n) : + big_sum g (Nsum_index n g i) <= i < big_sum g (S (Nsum_index n g i)). +Proof. + induction n; [cbn in *; lia|]. + cbn. + bdestruct_one. + - cbn in *; lia. + - apply IHn; easy. +Qed. + +Lemma Nsum_index_spec_inv n g i k (Hk : k < n) : + big_sum g k <= i < big_sum g (S k) -> + Nsum_index n g i = k. +Proof. + fill_differences. + intros H. + induction x. + - rewrite Nat.add_0_r, Nat.add_comm. + simpl. + bdestructΩ'. + - rewrite Nat.add_succ_r. + simpl. + rewrite (big_sum_split _ k) by lia. + cbn in *. + bdestructΩ'. +Qed. + +Lemma Nsum_index_offset_spec n g i (Hi : i < big_sum g n) : + i = big_sum g (Nsum_index n g i) + Nsum_offset n g i + /\ Nsum_offset n g i < g (Nsum_index n g i). +Proof. + pose proof (Nsum_index_spec n g i Hi) as Hsum. + simpl in Hsum. + unfold Nsum_offset. + split; + lia. +Qed. + +Lemma Nsum_index_add_big_sum_l n dims i k + (Hi : i < dims k) (Hk : k < n) : + Nsum_index n dims (big_sum dims k + i) = + k. +Proof. + fill_differences. + induction x; [ + rewrite <- Nat.add_assoc, Nat.add_comm; + cbn; bdestructΩ'|]. + rewrite Nat.add_succ_r. + cbn. + rewrite (big_sum_split _ k _) by lia. + cbn. + simplify_bools_lia_one_kernel. + easy. +Qed. + +Lemma Nsum_offset_add_big_sum_l n dims i k + (Hi : i < dims k) (Hk : k < n) : + Nsum_offset n dims (big_sum dims k + i) = + i. +Proof. + unfold Nsum_offset. + rewrite Nsum_index_add_big_sum_l by auto. + lia. +Qed. + +Definition enlarge_permutation (n : nat) (f dims : nat -> nat) := + fun k => if big_sum dims n <=? k then k else + big_sum (dims ∘ f) + (perm_inv' n f (Nsum_index n dims k)) + + Nsum_offset n dims k. + + +Add Parametric Morphism n : (enlarge_permutation n) with signature + on_predicate_relation_l + (fun f => perm_bounded n f) + (perm_eq n) ==> perm_eq n ==> eq + as enlarge_permutation_perm_eq_to_eq. +Proof. + intros f f' [Hbdd Hf] dims dims' Hdims. + apply functional_extensionality; intros k. + unfold enlarge_permutation. + rewrite (big_sum_eq_bounded _ _ n Hdims). + bdestructΩ'. + bdestruct (n =? 0); [subst; cbn in *; lia|]. + f_equal. + - rewrite <- (perm_inv'_eq_of_perm_eq n f f' Hf). + assert (Hrw : perm_eq n (dims ∘ f) (dims' ∘ f')) by + now rewrite Hdims, Hf. + rewrite (Nsum_index_perm_eq_to_eq n _ _ Hdims). + apply big_sum_eq_bounded. + intros i Hi. + apply Hrw. + eapply Nat.lt_trans; [eassumption|]. + pose proof (Nsum_index_bounded n dims' k ltac:(auto)) as Hlt. + auto with perm_bounded_db. + - unfold Nsum_offset. + rewrite (Nsum_index_perm_eq_to_eq n _ _ Hdims). + f_equal. + apply big_sum_eq_bounded. + intros i Hi. + apply Hdims. + pose proof (Nsum_index_bounded n dims' k ltac:(auto)) as Hlt. + lia. +Qed. + + +Add Parametric Morphism n : (enlarge_permutation n) with signature + perm_eq n ==> eq ==> eq as enlarge_permutation_perm_eq_to_eq_to_eq. +Proof. + intros f f' Hf dims. + apply functional_extensionality; intros k. + unfold enlarge_permutation. + bdestructΩ'. + bdestruct (n =? 0); [subst; cbn in *; lia|]. + f_equal. + rewrite <- (perm_inv'_eq_of_perm_eq n f f' Hf). + apply big_sum_eq_bounded. + intros i Hi. + unfold compose. + f_equal. + apply Hf. + eapply Nat.lt_trans; [eassumption|]. + pose proof (Nsum_index_bounded n dims k ltac:(auto)) as Hlt. + auto with perm_bounded_db. +Qed. + +Lemma enlarge_permutation_add_big_sum_l n f dims i k + (Hi : i < dims k) (Hk : k < n) : + enlarge_permutation n f dims + (big_sum dims k + i) = + big_sum (dims ∘ f) (perm_inv' n f k) + i. +Proof. + unfold enlarge_permutation. + rewrite (big_sum_split n k dims Hk). + cbn. + simplify_bools_lia_one_kernel. + now rewrite Nsum_index_add_big_sum_l, + Nsum_offset_add_big_sum_l by auto. +Qed. + +Lemma enlarge_permutation_WF n f dims : + WF_Perm (big_sum dims n) (enlarge_permutation n f dims). +Proof. + intros k Hk. + unfold enlarge_permutation. + bdestructΩ'. +Qed. + +#[export] Hint Resolve enlarge_permutation_WF : WF_Perm_db. + +Lemma enlarge_permutation_compose' n f g dims dims' + (Hdims : perm_eq n (dims ∘ f) dims') + (Hf : permutation n f) (Hg : permutation n g) : + perm_eq (big_sum dims n) + (enlarge_permutation n g dims' ∘ enlarge_permutation n f dims) + (enlarge_permutation n (f ∘ g) dims). +Proof. + intros k Hk. + rewrite <- Hdims. + unfold compose at 1. + unfold enlarge_permutation at 2. + simplify_bools_lia_one_kernel. + assert (Hn : n <> 0) by (intros ->; cbn in *; lia). + pose proof (Nsum_index_bounded n dims k Hn). + rewrite enlarge_permutation_add_big_sum_l. + 3: auto with perm_bounded_db. + 2: { + pose proof (Nsum_index_offset_spec n dims k Hk). + unfold compose. + rewrite perm_inv'_eq by auto with perm_bounded_db. + rewrite perm_inv_is_rinv_of_permutation by auto. + lia. + } + rewrite Combinators.compose_assoc. + unfold enlarge_permutation. + simplify_bools_lia_one_kernel. + rewrite perm_inv'_compose by auto. + easy. +Qed. + +Lemma enlarge_permutation_bounded n f dims (Hf : permutation n f) : + perm_bounded (big_sum dims n) (enlarge_permutation n f dims). +Proof. + intros k Hk. + unfold enlarge_permutation. + simplify_bools_lia_one_kernel. + rewrite (Nsum_reorder n dims (f)) by auto with perm_db. + pose proof (Nsum_index_offset_spec n dims k Hk). + assert (Hn : n <> 0) by (intros ->; cbn in *; lia). + pose proof (Nsum_index_bounded n dims k Hn) as Hidx. + rewrite (big_sum_split n (perm_inv' n f (Nsum_index n dims k))) + by auto with perm_bounded_db. + unfold compose at 3. + rewrite perm_inv'_eq, perm_inv_is_rinv_of_permutation + by auto with perm_bounded_db. + cbn. + lia. +Qed. + +#[export] Hint Resolve enlarge_permutation_bounded : perm_bounded_db. + +Lemma enlarge_permutation_defn n f dims : + perm_eq (big_sum dims n) + (enlarge_permutation n f dims) + (fun k => big_sum (dims ∘ f) + (perm_inv' n f (Nsum_index n dims k)) + + Nsum_offset n dims k). +Proof. + intros k Hk. + unfold enlarge_permutation. + bdestructΩ'. +Qed. + +Lemma enlarge_permutation_idn n dims : + enlarge_permutation n idn dims = idn. +Proof. + eq_by_WF_perm_eq (big_sum dims n). + rewrite enlarge_permutation_defn. + intros k Hk. + rewrite idn_inv', compose_idn_r. + symmetry. + now apply Nsum_index_offset_spec. +Qed. + + +Lemma enlarge_permutation_permutation n f dims (Hf : permutation n f) : + permutation (big_sum dims n) (enlarge_permutation n f dims). +Proof. + rewrite permutation_defn. + assert (Hfinv : permutation n (perm_inv' n f)) by auto with perm_db. + exists (enlarge_permutation n (perm_inv' n f) (dims ∘ f)). + repeat split. + - auto with perm_bounded_db. + - rewrite (Nsum_reorder n dims _ Hf). + auto with perm_bounded_db perm_db. + - rewrite (Nsum_reorder n dims _ Hf). + rewrite enlarge_permutation_compose' by cleanup_perm_inv. + rewrite perm_inv'_eq, perm_inv_linv_of_permutation by assumption. + now rewrite enlarge_permutation_idn. + - rewrite enlarge_permutation_compose' by cleanup_perm_inv. + rewrite perm_inv'_eq, perm_inv_rinv_of_permutation by assumption. + now rewrite enlarge_permutation_idn. +Qed. + +#[export] Hint Resolve enlarge_permutation_permutation : perm_db. + +Lemma enlarge_permutation_inv n f dims (Hf : permutation n f) : + perm_eq (big_sum dims n) + (perm_inv (big_sum dims n) (enlarge_permutation n f dims)) + (enlarge_permutation n (perm_inv n f) (dims ∘ f)). +Proof. + perm_eq_by_inv_inj (enlarge_permutation n f dims) (big_sum dims n). + rewrite enlarge_permutation_compose' by auto_perm. + rewrite perm_inv_rinv_of_permutation by auto. + now rewrite enlarge_permutation_idn. +Qed. + +Lemma enlarge_permutation_inv' n f dims (Hf : permutation n f) : + perm_inv' (big_sum dims n) (enlarge_permutation n f dims) = + enlarge_permutation n (perm_inv' n f) (dims ∘ f). +Proof. + eq_by_WF_perm_eq (big_sum dims n); + [rewrite (Nsum_reorder n dims f Hf); auto_perm..|]. + rewrite 2!perm_inv'_eq. + now apply enlarge_permutation_inv. +Qed. + + +Definition swap_2_to_2_perm a b c d n := + fun k => + if n <=? k then k else + if b =? c then ( + if k =? a then b else + if k =? b then d else + if k =? d then a else k + ) else if a =? d then ( + if k =? a then c else + if k =? c then b else + if k =? b then a else k + ) else ( + if k =? a then c else + if k =? b then d else + if k =? c then a else + if k =? d then b else k). + +Lemma swap_2_to_2_perm_WF a b c d n : + WF_Perm n (swap_2_to_2_perm a b c d n). +Proof. + intros k Hk. + unfold swap_2_to_2_perm; bdestructΩ'. +Qed. + +#[export] Hint Resolve swap_2_to_2_perm_WF : WF_Perm_db. + +Lemma swap_2_to_2_perm_invol a b c d n + (Ha : a < n) (Hb : b < n) (Hc : c < n) (Hd : d < n) + (Hab : a <> b) (Hbc : b <> c) (Hcd : c <> d) + (Had : a <> d) : + swap_2_to_2_perm a b c d n ∘ swap_2_to_2_perm a b c d n = idn. +Proof. + eq_by_WF_perm_eq n. + intros k Hk. + unfold swap_2_to_2_perm, compose. + do 2 simplify_bools_lia_one_kernel. + bdestructΩ'. +Qed. + +#[export] Hint Resolve swap_2_to_2_perm_invol : perm_inv_db. + +Lemma swap_2_to_2_perm_bounded a b c d n + (Ha : a < n) (Hb : b < n) (Hc : c < n) (Hd : d < n) : + perm_bounded n (swap_2_to_2_perm a b c d n). +Proof. + intros k Hk. + unfold swap_2_to_2_perm. + simplify_bools_lia_one_kernel. + bdestructΩ'. +Qed. + +#[export] Hint Resolve swap_2_to_2_perm_bounded : perm_bounded_db. + +Lemma swap_2_to_2_perm_permutation a b c d n + (Ha : a < n) (Hb : b < n) (Hc : c < n) (Hd : d < n) + (Hab : a <> b) (Hcd : c <> d) : + permutation n (swap_2_to_2_perm a b c d n). +Proof. + bdestruct (b =? c); + [|bdestruct (a =? d)]. + - exists (swap_2_to_2_perm d b b a n). + intros k Hk; repeat split; + unfold swap_2_to_2_perm; + do 2 simplify_bools_lia_one_kernel; + bdestructΩ'. + - exists (swap_2_to_2_perm a c b a n). + intros k Hk; repeat split; + unfold swap_2_to_2_perm; + do 2 simplify_bools_lia_one_kernel; + bdestructΩ'. + - perm_by_inverse (swap_2_to_2_perm a b c d n). +Qed. + +#[export] Hint Resolve swap_2_to_2_perm_permutation : perm_db. + +Lemma swap_2_to_2_perm_first a b c d n (Ha : a < n) : + swap_2_to_2_perm a b c d n a = c. +Proof. + unfold swap_2_to_2_perm; bdestructΩ'. +Qed. + +Lemma swap_2_to_2_perm_second a b c d n (Ha : b < n) (Hab : a <> b) : + swap_2_to_2_perm a b c d n b = d. +Proof. + unfold swap_2_to_2_perm. + bdestructΩ'. +Qed. + + + +Lemma perm_eq_of_small_eq_idn n m f (Hm : n <= m) + (Hf : permutation m f) (Hfeq : perm_eq n f idn) : + perm_eq m f (stack_perms n (m - n) idn (fun k => f (k + n) - n)). +Proof. + assert (Hfeqinv : forall k, k < m -> f k < n -> k < n). 1:{ + intros k Hk Hfk. + enough (f k = k) by lia. + apply (permutation_is_injective m f Hf); [lia..|]. + now apply Hfeq. + } + assert (Hfbig : forall k, n <= k < m -> n <= f k). 1: { + intros k []. + bdestructΩ (n <=? f k). + specialize (Hfeqinv k); lia. + } + intros k Hk. + bdestruct (k n <= f k. +Proof. + assert (Hfeqinv : forall k, k < m -> f k < n -> k < n). 1:{ + intros k Hk Hfk. + enough (f k = k) by lia. + apply (permutation_is_injective m f Hf); [lia..|]. + now apply Hfeq. + } + intros k []. + bdestructΩ (n <=? f k). + specialize (Hfeqinv k); lia. +Qed. + +Lemma perm_inv_perm_eq_idn_of_perm_eq_idn_up_to n m f (Hm : n <= m) + (Hf : permutation m f) (Hfeq : perm_eq n f idn) : + perm_eq n (perm_inv m f) idn. +Proof. + intros k Hk. + apply (permutation_is_injective m f Hf); [auto with perm_bounded_db..|]. + cleanup_perm. + symmetry. + now apply Hfeq. +Qed. + +Lemma perm_shift_permutation_of_small_eq_idn n m f (Hm : n <= m) + (Hf : permutation m f) (Hfeq : perm_eq n f idn) : + permutation (m - n) (fun k => f (k + n) - n). +Proof. + pose proof (perm_big_of_small_eq_idn n m f Hm Hf Hfeq) as Hfbig. + pose proof (perm_big_of_small_eq_idn n m _ Hm (perm_inv_permutation m f Hf) + (perm_inv_perm_eq_idn_of_perm_eq_idn_up_to n m f Hm Hf Hfeq)) + as Hfinvbig. + exists (fun k => (perm_inv m f (k + n) - n)). + intros k Hk; repeat split. + - pose proof (permutation_is_bounded m f Hf (k + n)). + lia. + - pose proof (perm_inv_bounded m f (k + n)). + lia. + - rewrite Nat.sub_add by (apply Hfbig; lia). + cleanup_perm; + lia. + - rewrite Nat.sub_add by (apply Hfinvbig; lia). + cleanup_perm; + lia. +Qed. + +#[export] Hint Resolve perm_shift_permutation_of_small_eq_idn : perm_db. \ No newline at end of file diff --git a/PermutationMatrices.v b/PermutationMatrices.v new file mode 100644 index 0000000..7830e71 --- /dev/null +++ b/PermutationMatrices.v @@ -0,0 +1,1534 @@ +Require Import VectorStates. +Require Import Kronecker. +Require Export PermutationsBase. +Require Import PermutationAutomation. +Require Import PermutationInstances. +Require Import Modulus. +Require Import Pad. +Require Import Complex. +Import Setoid. + +(** Implementation of permutations as matrices and facts about those matrices. **) + +(** * Prerequisite lemmas **) + +Lemma basis_vector_equiv_e_i : forall n k, + basis_vector n k ≡ e_i k. +Proof. + intros n k i j Hi Hj. + unfold basis_vector, e_i. + bdestructΩ'. +Qed. + +Lemma basis_vector_eq_e_i : forall n k, (k < n)%nat -> + basis_vector n k = e_i k. +Proof. + intros n k Hk. + rewrite <- mat_equiv_eq_iff by auto with wf_db. + apply basis_vector_equiv_e_i. +Qed. + +Lemma vector_equiv_basis_comb : forall n (y : Vector n), + y ≡ big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y. + intros i j Hi Hj. + replace j with O by lia; clear j Hj. + symmetry. + rewrite Msum_Csum. + apply big_sum_unique. + exists i. + repeat split; try easy. + - unfold ".*", e_i; bdestructΩ'; now Csimpl. + - intros l Hl Hnk. + unfold ".*", e_i; bdestructΩ'; now Csimpl. +Qed. + +Lemma vector_eq_basis_comb : forall n (y : Vector n), + WF_Matrix y -> + y = big_sum (G:=Vector n) (fun i => y i O .* @e_i n i) n. +Proof. + intros n y Hwfy. + apply mat_equiv_eq; auto with wf_db. + apply vector_equiv_basis_comb. +Qed. + +Lemma Mmult_if_r {n m o} (A : Matrix n m) (B B' : Matrix m o) (b : bool) : + A × (if b then B else B') = + if b then A × B else A × B'. +Proof. + now destruct b. +Qed. + +Lemma Mmult_if_l {n m o} (A A' : Matrix n m) (B : Matrix m o) (b : bool) : + (if b then A else A') × B = + if b then A × B else A' × B. +Proof. + now destruct b. +Qed. + + +Definition direct_sum' {n m o p : nat} (A : Matrix n m) (B : Matrix o p) : + Matrix (n+o) (m+p) := + (fun i j => if (i WF_Matrix B -> + A .⊕' B = A .⊕ B. +Proof. + intros n m o p A B HA HB. + apply mat_equiv_eq; [|apply WF_direct_sum|]; auto with wf_db. + intros i j Hi Hj. + unfold direct_sum, direct_sum'. + bdestruct_all; try lia + easy; + rewrite HA by lia; easy. +Qed. + +Lemma direct_sum'_simplify_mat_equiv {n m o p} : forall (A B : Matrix n m) + (C D : Matrix o p), A ≡ B -> C ≡ D -> direct_sum' A C ≡ direct_sum' B D. +Proof. + intros A B C D HAB HCD i j Hi Hj. + unfold direct_sum'. + bdestruct (i + @direct_sum' n m o p A Zero = A. +Proof. + intros HA. + prep_matrix_equality. + unfold direct_sum', Zero. + symmetry. + bdestructΩ'_with ltac: + (try lia; try rewrite HA by lia; try reflexivity). +Qed. + +Lemma direct_sum_Mscale {n m p q} (A : Matrix n m) + (B : Matrix p q) (c : C) : + (c .* A) .⊕' (c .* B) = c .* (A .⊕' B). +Proof. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + autounfold with U_db. + bdestruct_all; simpl; now Csimpl. +Qed. + +Lemma ei_direct_sum_split n m k : + @e_i (n + m) k = + (if k (@mat_equiv o p) + ==> (@mat_equiv (n+o) (m+p)) as direct_sum'_mat_equiv_morph. +Proof. intros; apply direct_sum'_simplify_mat_equiv; easy. Qed. + +Lemma ei_kron_split k n m : + @e_i (n*m) k = + e_i (k / m) ⊗ e_i (k mod m). +Proof. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia. + unfold e_i, kron. + do 2 simplify_bools_lia_one_kernel. + do 2 simplify_bools_moddy_lia_one_kernel. + rewrite Cmult_if_if_1_l. + apply f_equal_if; [|easy..]. + now rewrite andb_comm, <- eqb_iff_div_mod_eqb. +Qed. + +Lemma ei_kron_join k l n m : + (l < m)%nat -> + @e_i n k ⊗ e_i l = + @e_i (n*m) (k*m + l). +Proof. + intros Hl. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + replace j with O by lia. + unfold e_i, kron. + do 2 simplify_bools_lia_one_kernel. + do 2 simplify_bools_moddy_lia_one_kernel. + rewrite Cmult_if_if_1_l. + apply f_equal_if; [|easy..]. + symmetry. + rewrite (eqb_iff_div_mod_eqb m). + rewrite mod_add_l, Nat.div_add_l by lia. + rewrite (Nat.mod_small l m Hl), (Nat.div_small l m Hl). + now rewrite Nat.add_0_r, andb_comm. +Qed. + +Local Open Scope nat_scope. + +Lemma matrix_times_basis_eq_lt {m n : nat} (A : Matrix m n) (i j : nat) : + j < n -> (A × basis_vector n j) i 0 = A i j. +Proof. + intros Hj. + unfold Mmult. + rewrite (big_sum_eq_bounded _ (fun k => if k =? j then A i j else 0%R)%C). + 2: { + intros k Hk. + unfold basis_vector. + bdestructΩ'; lca. + } + rewrite big_sum_if_eq_C. + bdestructΩ'. +Qed. + +Lemma matrix_times_basis_mat_equiv {m n : nat} (A : Matrix m n) (j : nat) : + j < n -> mat_equiv (A × basis_vector n j) + (get_col A j). +Proof. + intros Hj i z Hi Hz. + replace z with 0 by lia. + rewrite matrix_times_basis_eq_lt by easy. + unfold get_col. + bdestructΩ'. +Qed. + +Lemma matrix_conj_basis_eq_lt {m n : nat} (A : Matrix m n) (i j : nat) : + i < m -> j < n -> ((basis_vector m i)⊤ × A × basis_vector n j) 0 0 = A i j. +Proof. + intros Hi Hj. + rewrite matrix_times_basis_mat_equiv by lia. + unfold get_col. + bdestructΩ'. + unfold Mmult, Matrix.transpose. + rewrite (big_sum_eq_bounded _ (fun k => if k =? i then A i j else 0%R)%C). + 2: { + intros k Hk. + unfold basis_vector. + bdestructΩ'; lca. + } + rewrite big_sum_if_eq_C. + bdestructΩ'. +Qed. + +Lemma mat_equiv_of_all_basis_conj {m n : nat} (A B : Matrix m n) + (H : forall (i j : nat), i < m -> j < n -> + ((basis_vector m i) ⊤ × A × basis_vector n j) 0 0 = + ((basis_vector m i) ⊤ × B × basis_vector n j) 0 0) : + mat_equiv A B. +Proof. + intros i j Hi Hj. + specialize (H i j Hi Hj). + now rewrite 2!matrix_conj_basis_eq_lt in H by easy. +Qed. + +Local Open Scope nat_scope. + +(** * Permutation matrices *) +Definition perm_mat n (p : nat -> nat) : Square n := + (fun x y => if (x =? p y) && (x if x =? p y then C1 else C0. +Proof. + intros i j Hi Hj. + unfold perm_mat. + bdestructΩ'. +Qed. + +Add Parametric Morphism n : (perm_mat n) with signature + perm_eq n ==> eq as perm_mat_perm_eq_to_eq_proper. +Proof. + intros f g Hfg. + apply mat_equiv_eq; auto with wf_db. + rewrite !perm_mat_defn. + intros i j Hi Hj. + now rewrite Hfg by easy. +Qed. + +Lemma perm_mat_id : forall n, + perm_mat n (Datatypes.id) = (I n). +Proof. + intros n. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + unfold Datatypes.id, perm_mat, I. + bdestructΩ'. +Qed. + +Lemma perm_mat_idn n : + perm_mat n idn = (I n). +Proof. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + unfold perm_mat, I. + bdestructΩ'. +Qed. + +#[export] Hint Rewrite perm_mat_idn : perm_cleanup_db. + +Lemma perm_mat_unitary : forall n p, + permutation n p -> WF_Unitary (perm_mat n p). +Proof. + intros n p [pinv Hp]. + split; [apply perm_mat_WF|]. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + unfold Mmult, adjoint, perm_mat, I. + do 2 simplify_bools_lia_one_kernel. + bdestruct (i =? j). + - subst. + apply big_sum_unique. + exists (p j). + destruct (Hp j) as [? _]; auto. + split; auto. + split; intros; bdestructΩ'; lca. + - apply (@big_sum_0 C C_is_monoid). + intros z. + bdestruct_all; simpl; try lca. + subst. + assert (pinv (p i) = pinv (p j)) by auto. + pose proof (fun x Hx => proj1 (proj2 (proj2 (Hp x Hx)))) as Hrw. + rewrite !Hrw in * by auto. + congruence. +Qed. + +Lemma perm_mat_Mmult n f g : + perm_bounded n g -> + perm_mat n f × perm_mat n g = perm_mat n (f ∘ g)%prg. +Proof. + intros Hg. + apply mat_equiv_eq; auto with wf_db. + intros i j Hi Hj. + unfold perm_mat, Mmult, compose. + do 2 simplify_bools_lia_one_kernel. + bdestruct (i =? f (g j)). + - apply big_sum_unique. + exists (g j). + specialize (Hg j Hj). + split; [now auto|]. + split; [bdestructΩ'; now Csimpl|]. + intros k Hk ?. + bdestructΩ'; now Csimpl. + - apply (@big_sum_0_bounded C). + intros k Hk. + bdestructΩ'; now Csimpl. +Qed. + +Lemma perm_mat_I : forall n f, + (forall x, x < n -> f x = x) -> + perm_mat n f = I n. +Proof. + intros n f Hinv. + apply mat_equiv_eq; auto with wf_db. + unfold perm_mat, I. + intros i j Hi Hj. + do 2 simplify_bools_lia_one_kernel. + now rewrite Hinv by easy. +Qed. + +Lemma perm_mat_col_swap : forall n f i j, + i < n -> j < n -> + perm_mat n (fswap f i j) = col_swap (perm_mat n f) i j. +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + intros k l Hk Hl. + unfold perm_mat, fswap, col_swap, I. + bdestructΩ'. +Qed. + +Lemma perm_mat_col_swap_I : forall n f i j, + (forall x, x < n -> f x = x) -> + i < n -> j < n -> + perm_mat n (fswap f i j) = col_swap (I n) i j. +Proof. + intros. + rewrite perm_mat_col_swap by easy. + now rewrite perm_mat_I by easy. +Qed. + + +Lemma perm_mat_row_swap : forall n f i j, + i < n -> j < n -> + perm_mat n (fswap f i j) = (row_swap (perm_mat n f)† i j)†. +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + intros k l Hk Hl. + unfold perm_mat, fswap, row_swap, I, adjoint. + do 3 simplify_bools_lia_one_kernel. + rewrite !(if_dist _ _ _ Cconj). + Csimpl. + bdestructΩ'. +Qed. + +Lemma perm_mat_e_i : forall n f i, + i < n -> + (perm_mat n f) × e_i i = e_i (f i). +Proof. + intros n f i Hi. + apply mat_equiv_eq; auto with wf_db. + unfold mat_equiv; intros k l Hk Hl. + replace l with 0 in * by lia. + unfold Mmult. + apply big_sum_unique. + exists i. + split; auto. + unfold e_i, perm_mat; + split; [bdestructΩ'_with ltac:(try lia; try lca)|]. + intros. + bdestructΩ'_with ltac:(try lia; try lca). +Qed. + +(* with get_entry_with_e_i this became soo much easier *) +Lemma perm_mat_conjugate : forall {n} (A : Square n) f (i j : nat), + WF_Matrix A -> + i < n -> j < n -> + perm_bounded n f -> + ((perm_mat n f)† × A × ((perm_mat n f))) i j = A (f i) (f j). +Proof. + intros. + rewrite get_entry_with_e_i, (get_entry_with_e_i A) + by auto with perm_bounded_db. + rewrite <- 2 Mmult_assoc, <- Mmult_adjoint. + rewrite perm_mat_e_i by auto with perm_bounded_db. + rewrite 3 Mmult_assoc. + rewrite perm_mat_e_i; auto. +Qed. + +Lemma perm_mat_conjugate_nonsquare : + forall {m n} (A : Matrix m n) f g (i j : nat), + WF_Matrix A -> + i < m -> j < n -> + perm_bounded m g -> perm_bounded n f -> + ((perm_mat m g)† × A × ((perm_mat n f))) i j = A (g i) (f j). +Proof. + intros. + rewrite get_entry_with_e_i, (get_entry_with_e_i A) by auto. + rewrite <- 2 Mmult_assoc, <- Mmult_adjoint. + rewrite perm_mat_e_i by auto. + rewrite 3 Mmult_assoc. + rewrite perm_mat_e_i; auto. +Qed. + +Lemma perm_mat_permutes_basis_vectors_r : forall n f k, (k < n)%nat -> + (perm_mat n f) × (basis_vector n k) = e_i (f k). +Proof. + intros n f k Hk. + rewrite basis_vector_eq_e_i by easy. + apply perm_mat_e_i; easy. +Qed. + +Lemma perm_mat_permutes_matrix_r : forall n m f (A : Matrix n m), + permutation n f -> + (perm_mat n f) × A ≡ (fun i j => A (perm_inv n f i) j). +Proof. + intros n m f A Hperm. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite Mmult_assoc, <- 2(matrix_by_basis _ _ Hk). + rewrite (vector_equiv_basis_comb _ (get_col _ _)). + rewrite Mmult_Msum_distr_l. + erewrite big_sum_eq_bounded. + 2: { + intros l Hl. + rewrite Mscale_mult_dist_r, perm_mat_e_i by easy. + reflexivity. + } + intros i j Hi Hj; replace j with O by lia; clear j Hj. + rewrite Msum_Csum. + unfold get_col, scale, e_i. + rewrite Nat.eqb_refl. + apply big_sum_unique. + exists (perm_inv n f i). + repeat split; auto with perm_bounded_db. + - rewrite (perm_inv_is_rinv_of_permutation n f Hperm i Hi), Nat.eqb_refl. + bdestructΩ'; now Csimpl. + - intros j Hj Hjne. + bdestruct (i =? f j); [|bdestructΩ'; now Csimpl]. + exfalso; apply Hjne. + apply (permutation_is_injective n f Hperm); auto with perm_bounded_db. + rewrite (perm_inv_is_rinv_of_permutation n f Hperm i Hi); easy. +Qed. + +Lemma perm_mat_equiv_of_perm_eq : forall n f g, + (perm_eq n f g) -> + perm_mat n f ≡ perm_mat n g. +Proof. + intros n f g Heq. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite 2!perm_mat_e_i, Heq by easy. + easy. +Qed. + +#[export] Hint Resolve perm_mat_equiv_of_perm_eq : perm_inv_db. + +Lemma perm_mat_eq_of_perm_eq : forall n f g, + (perm_eq n f g) -> + perm_mat n f = perm_mat n g. +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + now apply perm_mat_equiv_of_perm_eq. +Qed. + +#[export] Hint Resolve perm_mat_eq_of_perm_eq : perm_inv_db. + +Lemma perm_mat_equiv_of_perm_eq' : forall n m f g, n = m -> + (perm_eq n f g) -> + perm_mat n f ≡ perm_mat m g. +Proof. + intros; subst n; apply perm_mat_equiv_of_perm_eq; easy. +Qed. + +Lemma perm_mat_transpose {n f} (Hf : permutation n f) : + (perm_mat n f) ⊤ ≡ perm_mat n (perm_inv n f). +Proof. + intros i j Hi Hj. + unfold "⊤". + unfold perm_mat. + simplify_bools_lia. + rewrite <- (@perm_inv_eqb_iff n f) by cleanup_perm. + now rewrite Nat.eqb_sym. +Qed. + +Lemma perm_mat_transpose_eq {n f} (Hf : permutation n f) : + (perm_mat n f) ⊤ = perm_mat n (perm_inv n f). +Proof. + apply mat_equiv_eq; auto with wf_db. + now apply perm_mat_transpose. +Qed. + +Lemma matrix_by_basis_perm_eq {n m} (A : Matrix n m) (i : nat) (Hi : i < m) : + get_col A i ≡ A × e_i i. +Proof. + intros k l Hk Hl. + replace l with 0 by lia. + unfold get_col. + simplify_bools_lia_one_kernel. + symmetry. + unfold Mmult, e_i. + simplify_bools_lia_one_kernel. + apply big_sum_unique. + exists i. + split; auto. + do 2 simplify_bools_lia_one_kernel. + split; intros; + simplify_bools_lia; + now Csimpl. +Qed. + +Lemma perm_mat_permutes_matrix_l : forall n m f (A : Matrix n m), + perm_bounded m f -> + A × (perm_mat m f) ≡ (fun i j => A i (f j)). +Proof. + intros n m f A Hf. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite Mmult_assoc, perm_mat_e_i, <- (matrix_by_basis _ _ Hk) by easy. + rewrite <- matrix_by_basis_perm_eq by auto with perm_bounded_db. + easy. +Qed. + +Lemma perm_mat_permutes_matrix_l_eq : forall n m f (A : Matrix n m), + WF_Matrix A -> + perm_bounded m f -> + A × (perm_mat m f) = make_WF (fun i j => A i (f j)). +Proof. + intros n m f A HA Hf. + apply mat_equiv_eq; auto with wf_db. + rewrite make_WF_equiv. + now apply perm_mat_permutes_matrix_l. +Qed. + +Lemma perm_mat_permutes_matrix_r_eq : forall n m f (A : Matrix n m), + WF_Matrix A -> + permutation n f -> + (perm_mat n f) × A = make_WF (fun i j => A (perm_inv n f i) j). +Proof. + intros n m f A HA Hf. + apply mat_equiv_eq; auto with wf_db. + rewrite make_WF_equiv. + now apply perm_mat_permutes_matrix_r. +Qed. + +Lemma perm_mat_perm_eq_idn n f : + perm_eq n f idn -> + perm_mat n f = I n. +Proof. + intros ->. + apply perm_mat_idn. +Qed. + +Lemma perm_mat_transpose_rinv {n f} (Hf : permutation n f) : + (perm_mat n f) × (perm_mat n f) ⊤ = I n. +Proof. + rewrite perm_mat_transpose_eq by easy. + rewrite perm_mat_Mmult by auto with perm_db. + cleanup_perm. +Qed. + +Lemma perm_mat_transpose_linv {n f} (Hf : permutation n f) : + (perm_mat n f) ⊤ × (perm_mat n f) = I n. +Proof. + rewrite perm_mat_transpose_eq by easy. + rewrite perm_mat_Mmult by auto with perm_db. + cleanup_perm. +Qed. + +Lemma perm_mat_of_stack_perms n0 n1 f g : + perm_bounded n0 f -> perm_bounded n1 g -> + perm_mat (n0 + n1) (stack_perms n0 n1 f g) = + direct_sum' (perm_mat n0 f) (perm_mat n1 g). +Proof. + intros Hf Hg. + apply mat_equiv_eq; auto with wf_db. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite perm_mat_e_i by easy. + rewrite 2!ei_direct_sum_split. + rewrite Mmult_if_r. + rewrite (direct_sum'_Mmult _ _ (e_i k) (Zero)). + rewrite (direct_sum'_Mmult _ _ (@Zero n0 0) (e_i (k - n0))). + rewrite 2!Mmult_0_r. + bdestruct (k + perm_mat (n0 * n1) (tensor_perms n0 n1 f g) = + perm_mat n0 f ⊗ perm_mat n1 g. +Proof. + intros Hg. + apply mat_equiv_eq; auto with wf_db. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite perm_mat_e_i by easy. + symmetry. + rewrite ei_kron_split. + restore_dims. + rewrite kron_mixed_product. + unfold tensor_perms. + simplify_bools_lia_one_kernel. + rewrite 2!perm_mat_e_i by show_moddy_lt. + now rewrite ei_kron_join by cleanup_perm. +Qed. + +Lemma perm_mat_inj_mat_equiv n f g + (Hf : perm_bounded n f) (Hg : perm_bounded n g) : + perm_mat n f ≡ perm_mat n g -> + perm_eq n f g. +Proof. + intros Hequiv. + intros i Hi. + generalize (Hequiv (f i) i (Hf i Hi) Hi). + unfold perm_mat. + pose proof (Hf i Hi). + pose proof C1_nonzero. + bdestructΩ'. +Qed. + +Lemma perm_mat_inj n f g + (Hf : perm_bounded n f) (Hg : perm_bounded n g) : + perm_mat n f = perm_mat n g -> + perm_eq n f g. +Proof. + rewrite <- mat_equiv_eq_iff by auto with wf_db. + now apply perm_mat_inj_mat_equiv. +Qed. + +Lemma perm_mat_determinant_sqr n f (Hf : permutation n f) : + (Determinant (perm_mat n f) ^ 2)%C = 1%R. +Proof. + simpl. + Csimpl. + rewrite Determinant_transpose at 1. + rewrite Determinant_multiplicative. + rewrite perm_mat_transpose_linv by easy. + now rewrite Det_I. +Qed. + +Lemma perm_mat_perm_eq_of_proportional n f g : + (exists c, perm_mat n f = c .* perm_mat n g /\ c <> 0%R) -> + perm_bounded n f -> + perm_eq n f g. +Proof. + intros (c & Heq & Hc) Hf. + rewrite <- mat_equiv_eq_iff in Heq by auto with wf_db. + intros i Hi. + pose proof (Hf i Hi) as Hfi. + generalize (Heq (f i) i Hfi Hi). + unfold perm_mat, scale. + do 3 simplify_bools_lia_one_kernel. + rewrite Cmult_if_1_r. + pose proof C1_nonzero. + bdestructΩ'. +Qed. + +Lemma perm_mat_eq_of_proportional n f g : + (exists c, perm_mat n f = c .* perm_mat n g /\ c <> 0%R) -> + perm_bounded n f -> + perm_mat n f = perm_mat n g. +Proof. + intros H Hf. + apply perm_mat_eq_of_perm_eq. + now apply perm_mat_perm_eq_of_proportional. +Qed. + +Lemma Mmult_perm_mat_l n m (A B : Matrix n m) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation n f) : + perm_mat n f × A = B <-> A = perm_mat n (perm_inv n f) × B. +Proof. + rewrite <- perm_mat_transpose_eq by auto. + split; [intros <- | intros ->]; + now rewrite <- Mmult_assoc, 1?perm_mat_transpose_rinv, + 1?perm_mat_transpose_linv, Mmult_1_l by auto. +Qed. + +Lemma Mmult_perm_mat_l' n m (A B : Matrix n m) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation n f) : + B = perm_mat n f × A <-> perm_mat n (perm_inv n f) × B = A. +Proof. + split; intros H; symmetry; + apply Mmult_perm_mat_l; auto. +Qed. + +Lemma Mmult_perm_mat_r n m (A B : Matrix n m) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation m f) : + A × perm_mat m f = B <-> A = B × perm_mat m (perm_inv m f). +Proof. + rewrite <- perm_mat_transpose_eq by auto. + split; [intros <- | intros ->]; + now rewrite Mmult_assoc, 1?perm_mat_transpose_rinv, + 1?perm_mat_transpose_linv, Mmult_1_r by auto. +Qed. + +Lemma Mmult_perm_mat_r' n m (A B : Matrix n m) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation m f) : + B = A × perm_mat m f <-> B × perm_mat m (perm_inv m f) = A. +Proof. + split; intros H; symmetry; + apply Mmult_perm_mat_r; auto. +Qed. + +(** Transform a (0,...,n-1) permutation into a 2^n by 2^n matrix. *) +Definition perm_to_matrix n p := + perm_mat (2 ^ n) (qubit_perm_to_nat_perm n p). + +Lemma perm_to_matrix_WF : forall n p, WF_Matrix (perm_to_matrix n p). +Proof. intros. apply perm_mat_WF. Qed. +#[export] Hint Resolve perm_to_matrix_WF : wf_db. + +Add Parametric Morphism n : (perm_to_matrix n) with signature + perm_eq n ==> eq as perm_to_matrix_perm_eq_to_eq_proper. +Proof. + intros f g Hfg. + unfold perm_to_matrix. + now rewrite Hfg. +Qed. + +Lemma perm_to_matrix_permutes_qubits : forall n p f, + perm_bounded n p -> + perm_to_matrix n p × f_to_vec n f = f_to_vec n (fun x => f (p x)). +Proof. + intros n p f Hp. + rewrite 2 basis_f_to_vec. + rewrite !basis_vector_eq_e_i by apply funbool_to_nat_bound. + unfold perm_to_matrix. + rewrite perm_mat_e_i by apply funbool_to_nat_bound. + f_equal. + rewrite qubit_perm_to_nat_perm_defn by apply funbool_to_nat_bound. + apply funbool_to_nat_eq. + intros i Hi. + unfold compose. + now rewrite funbool_to_nat_inverse by auto. +Qed. + +Lemma perm_to_matrix_unitary : forall n p, + permutation n p -> + WF_Unitary (perm_to_matrix n p). +Proof. + intros. + apply perm_mat_unitary. + auto with perm_db. +Qed. + + +Lemma Private_perm_to_matrix_Mmult : forall n f g, + permutation n f -> permutation n g -> + perm_to_matrix n f × perm_to_matrix n g = perm_to_matrix n (g ∘ f)%prg. +Proof. + intros n f g Hf Hg. + unfold perm_to_matrix. + rewrite perm_mat_Mmult by auto with perm_bounded_db. + now rewrite qubit_perm_to_nat_perm_compose by auto with perm_bounded_db. +Qed. + +#[deprecated(note="Use perm_to_matrix_compose instead")] +Notation perm_to_matrix_Mmult := Private_perm_to_matrix_Mmult. + +Lemma perm_to_matrix_I : forall n f, + (forall x, x < n -> f x = x) -> + perm_to_matrix n f = I (2 ^ n). +Proof. + intros n f Hf. + unfold perm_to_matrix. + apply perm_mat_I. + intros x Hx. + unfold qubit_perm_to_nat_perm, compose. + erewrite funbool_to_nat_eq. + 2: { intros y Hy. rewrite Hf by assumption. reflexivity. } + simplify_bools_lia_one_kernel. + apply nat_to_funbool_inverse. + assumption. +Qed. + +Lemma perm_to_matrix_perm_eq n f g : + perm_eq n f g -> + perm_to_matrix n f ≡ perm_to_matrix n g. +Proof. + intros Hfg. + apply perm_mat_equiv_of_perm_eq. + now rewrite Hfg. +Qed. + +#[export] Hint Resolve perm_to_matrix_perm_eq : perm_inv_db. + +Lemma perm_to_matrix_eq_of_perm_eq n f g : + perm_eq n f g -> + perm_to_matrix n f = perm_to_matrix n g. +Proof. + intros Hfg. + apply mat_equiv_eq; auto with wf_db. + now apply perm_to_matrix_perm_eq. +Qed. + +#[export] Hint Resolve perm_to_matrix_eq_of_perm_eq : perm_inv_db. + +Lemma perm_to_matrix_transpose {n f} (Hf : permutation n f) : + (perm_to_matrix n f) ⊤ ≡ perm_to_matrix n (perm_inv n f). +Proof. + unfold perm_to_matrix. + rewrite perm_mat_transpose by auto with perm_db. + cleanup_perm_inv. +Qed. + +Lemma perm_to_matrix_transpose_eq {n f} (Hf : permutation n f) : + (perm_to_matrix n f) ⊤ = perm_to_matrix n (perm_inv n f). +Proof. + apply mat_equiv_eq; auto with wf_db. + now apply perm_to_matrix_transpose. +Qed. + +Lemma perm_to_matrix_transpose' {n f} (Hf : permutation n f) : + (perm_to_matrix n f) ⊤ ≡ perm_to_matrix n (perm_inv' n f). +Proof. + rewrite perm_to_matrix_transpose by easy. + cleanup_perm. +Qed. + +Lemma perm_to_matrix_transpose_eq' {n f} (Hf : permutation n f) : + (perm_to_matrix n f) ⊤ = perm_to_matrix n (perm_inv' n f). +Proof. + apply mat_equiv_eq; auto with wf_db. + now apply perm_to_matrix_transpose'. +Qed. + +Lemma perm_to_matrix_transpose_linv {n f} (Hf : permutation n f) : + (perm_to_matrix n f) ⊤ × perm_to_matrix n f = I (2 ^ n). +Proof. + apply perm_mat_transpose_linv; auto with perm_db. +Qed. + +Lemma perm_to_matrix_transpose_rinv {n f} (Hf : permutation n f) : + perm_to_matrix n f × (perm_to_matrix n f) ⊤ = I (2 ^ n). +Proof. + apply perm_mat_transpose_rinv; auto with perm_db. +Qed. + +Lemma Mmult_perm_to_matrix_l n m (A B : Matrix (2 ^ n) m) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation n f) : + perm_to_matrix n f × A = B <-> A = perm_to_matrix n (perm_inv n f) × B. +Proof. + rewrite <- perm_to_matrix_transpose_eq by auto. + unfold perm_to_matrix. + split; [intros <- | intros ->]; + now rewrite <- Mmult_assoc, 1?perm_mat_transpose_rinv, + 1?perm_mat_transpose_linv, Mmult_1_l by auto with perm_db. +Qed. + +Lemma Mmult_perm_to_matrix_l' n m (A B : Matrix (2 ^ n) m) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation n f) : + B = perm_to_matrix n f × A <-> perm_to_matrix n (perm_inv n f) × B = A. +Proof. + split; intros H; symmetry; apply Mmult_perm_to_matrix_l; + auto. +Qed. + +Lemma Mmult_perm_to_matrix_r n m (A B : Matrix n (2 ^ m)) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation m f) : + A × perm_to_matrix m f = B <-> A = B × perm_to_matrix m (perm_inv m f). +Proof. + rewrite <- perm_to_matrix_transpose_eq by auto. + unfold perm_to_matrix. + split; [intros <- | intros ->]; + now rewrite Mmult_assoc, 1?perm_mat_transpose_rinv, + 1?perm_mat_transpose_linv, Mmult_1_r by auto with perm_db. +Qed. + +Lemma Mmult_perm_to_matrix_r' n m (A B : Matrix n (2 ^ m)) + (HA : WF_Matrix A) (HB : WF_Matrix B) f (Hf : permutation m f) : + B = A × perm_to_matrix m f <-> B × perm_to_matrix m (perm_inv m f) = A. +Proof. + split; intros H; symmetry; apply Mmult_perm_to_matrix_r; + auto. +Qed. + +Lemma perm_to_matrix_permutes_qubits_l n p f + (Hp : permutation n p) : + (f_to_vec n f) ⊤ × perm_to_matrix n p = + (f_to_vec n (fun x => f (perm_inv n p x))) ⊤. +Proof. + rewrite <- (transpose_involutive _ _ (perm_to_matrix _ _)). + rewrite <- Mmult_transpose. + rewrite perm_to_matrix_transpose_eq by easy. + f_equal. + apply perm_to_matrix_permutes_qubits. + auto with perm_bounded_db. +Qed. + +#[export] Hint Resolve perm_to_matrix_perm_eq + perm_to_matrix_eq_of_perm_eq : perm_inv_db. + +Lemma perm_to_matrix_of_stack_perms n0 n1 f g + (Hf : permutation n0 f) (Hg : permutation n1 g) : + perm_to_matrix (n0 + n1) (stack_perms n0 n1 f g) = + perm_to_matrix n0 f ⊗ perm_to_matrix n1 g. +Proof. + unfold perm_to_matrix. + rewrite <- perm_mat_of_tensor_perms by cleanup_perm. + rewrite <- Nat.pow_add_r. + cleanup_perm. +Qed. + +#[export] Hint Rewrite perm_to_matrix_of_stack_perms : perm_cleanup_db. + +Lemma perm_to_matrix_of_stack_perms' n0 n1 n01 f g + (Hf : permutation n0 f) (Hg : permutation n1 g) + (Hn01 : n0 + n1 = n01) : + perm_to_matrix n01 (stack_perms n0 n1 f g) = + perm_to_matrix n0 f ⊗ perm_to_matrix n1 g. +Proof. + subst. + now apply perm_to_matrix_of_stack_perms. +Qed. + +Lemma perm_to_matrix_idn n : + perm_to_matrix n idn = I (2^n). +Proof. + rewrite <- perm_mat_idn. + apply perm_mat_eq_of_perm_eq. + cleanup_perm_inv. +Qed. + +Lemma perm_to_matrix_compose n f g : + perm_bounded n f -> perm_bounded n g -> + perm_to_matrix n (f ∘ g) = + perm_to_matrix n g × perm_to_matrix n f. +Proof. + intros Hf Hg. + symmetry. + unfold perm_to_matrix. + rewrite <- qubit_perm_to_nat_perm_compose by easy. + apply perm_mat_Mmult. + auto with perm_bounded_db. +Qed. + +#[export] Hint Rewrite perm_to_matrix_compose : perm_cleanup_db. + +Lemma perm_to_matrix_inj_mat_equiv n f g + (Hf : perm_bounded n f) (Hg : perm_bounded n g) : + perm_to_matrix n f ≡ perm_to_matrix n g -> + perm_eq n f g. +Proof. + intros Hequiv. + apply qubit_perm_to_nat_perm_inj; [easy|]. + apply perm_mat_inj_mat_equiv; [auto with perm_bounded_db..|]. + exact Hequiv. +Qed. + +Lemma perm_to_matrix_inj n f g + (Hf : perm_bounded n f) (Hg : perm_bounded n g) : + perm_to_matrix n f = perm_to_matrix n g -> + perm_eq n f g. +Proof. + rewrite <- mat_equiv_eq_iff by auto with wf_db. + now apply perm_to_matrix_inj_mat_equiv. +Qed. + + +Lemma perm_to_matrix_perm_eq_of_proportional n f g : + (exists c, perm_to_matrix n f = + c .* perm_to_matrix n g /\ c <> 0%R) -> + perm_bounded n f -> + perm_eq n f g. +Proof. + intros H Hf. + pose proof (perm_mat_perm_eq_of_proportional _ _ _ H). + apply qubit_perm_to_nat_perm_inj; auto with perm_bounded_db. +Qed. + +Lemma perm_to_matrix_eq_of_proportional n f g : + (exists c, perm_to_matrix n f = + c .* perm_to_matrix n g /\ c <> 0%R) -> + perm_bounded n f -> + perm_to_matrix n f = perm_to_matrix n g. +Proof. + intros H Hf. + apply perm_to_matrix_eq_of_perm_eq. + now apply perm_to_matrix_perm_eq_of_proportional. +Qed. + +Lemma kron_comm_pows2_eq_perm_to_matrix_big_swap n o : + kron_comm (2^o) (2^n) = perm_to_matrix (n + o) (big_swap_perm o n). +Proof. + symmetry. + apply equal_on_basis_states_implies_equal; + [|rewrite WF_Matrix_dim_change_iff by show_pow2_le |]; + [auto with wf_db..|]. + intros f. + rewrite perm_to_matrix_permutes_qubits by auto with perm_db. + rewrite (f_to_vec_split'_eq _ _ f). + restore_dims. + rewrite kron_comm_commutes_vectors_l by auto with wf_db. + rewrite Nat.add_comm, f_to_vec_split'_eq. + f_equal; apply f_to_vec_eq; intros i Hi; f_equal; + unfold big_swap_perm; bdestructΩ'. +Qed. + +Lemma kron_comm_pows2_eq_perm_to_matrix_rotr n o : + kron_comm (2^o) (2^n) = perm_to_matrix (n + o) (rotr (n + o) n). +Proof. + rewrite kron_comm_pows2_eq_perm_to_matrix_big_swap. + now rewrite big_swap_perm_eq_rotr, Nat.add_comm. +Qed. + +Lemma kron_comm_eq_perm_mat_of_kron_comm_perm p q : + kron_comm p q = perm_mat (p * q) (kron_comm_perm p q). +Proof. + apply mat_equiv_eq; auto using WF_Matrix_dim_change with wf_db zarith. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + rewrite (Nat.div_mod_eq k p) at 1. + rewrite (Nat.mul_comm p (k/p)), (Nat.mul_comm q p). + rewrite <- (kron_e_i_e_i p q) at 1 by show_moddy_lt. + restore_dims. + rewrite kron_comm_commutes_vectors_l by auto with wf_db. + rewrite perm_mat_e_i by show_moddy_lt. + rewrite (kron_e_i_e_i q p) by show_moddy_lt. + rewrite Nat.mul_comm. + unfold kron_comm_perm. + bdestructΩ'. +Qed. + +Lemma perm_to_matrix_rotr_eq_kron_comm : forall n o, + perm_to_matrix (n + o) (rotr (n + o) n) = kron_comm (2^o) (2^n). +Proof. + intros n o. + now rewrite <- kron_comm_pows2_eq_perm_to_matrix_rotr. +Qed. + +#[export] Hint Rewrite perm_to_matrix_rotr_eq_kron_comm : perm_inv_db. + +Lemma perm_to_matrix_rotr_eq_kron_comm_alt : forall n o, + perm_to_matrix (n + o) (rotr (n + o) o) = kron_comm (2^n) (2^o). +Proof. + intros n o. + rewrite Nat.add_comm. + cleanup_perm_inv. +Qed. + +#[export] Hint Rewrite perm_to_matrix_rotr_eq_kron_comm_alt : perm_inv_db. + +Lemma perm_to_matrix_rotr_eq_kron_comm_mat_equiv : forall n o, + perm_to_matrix (n + o) (rotr (n + o) n) ≡ kron_comm (2^o) (2^n). +Proof. + intros n o. + now rewrite perm_to_matrix_rotr_eq_kron_comm. +Qed. + +#[export] Hint Resolve + perm_to_matrix_rotr_eq_kron_comm_mat_equiv : perm_inv_db. + +Lemma perm_to_matrix_rotl_eq_kron_comm : forall n o, + perm_to_matrix (n + o) (rotl (n + o) n) = kron_comm (2^n) (2^o). +Proof. + intros n o. + rewrite <- (perm_to_matrix_eq_of_perm_eq _ _ _ (rotr_inv (n + o) n)). + rewrite <- perm_to_matrix_transpose_eq by auto with perm_db. + rewrite perm_to_matrix_rotr_eq_kron_comm. + apply kron_comm_transpose. +Qed. + +#[export] Hint Rewrite perm_to_matrix_rotl_eq_kron_comm : perm_inv_db. + +Lemma perm_to_matrix_rotl_eq_kron_comm_mat_equiv : forall n o, + perm_to_matrix (n + o) (rotl (n + o) n) ≡ kron_comm (2^n) (2^o). +Proof. + intros. + now rewrite perm_to_matrix_rotl_eq_kron_comm. +Qed. + +#[export] Hint Resolve + perm_to_matrix_rotl_eq_kron_comm_mat_equiv : perm_inv_db. + +Lemma perm_to_matrix_rotr_commutes_kron_mat_equiv {n m p q} + (A : Matrix (2^n) (2^m)) (B : Matrix (2^p) (2^q)) : + @Mmult (2^n*2^p) (2^m*2^q) (2^q*2^m) + (A ⊗ B) (perm_to_matrix (q + m) (rotr (q + m) q)) ≡ + @Mmult (2^n*2^p) (2^p*2^n) (2^q*2^m) + (perm_to_matrix (p + n) (rotr (p + n) p)) (B ⊗ A). +Proof. + unify_pows_two. + rewrite 2!perm_to_matrix_rotr_eq_kron_comm. + restore_dims. + pose proof (kron_comm_commutes_r_mat_equiv (2^n) (2^m) + (2^p) (2^q) A B) as H. + apply H. +Qed. + +Lemma perm_to_matrix_rotr_commutes_kron {n m p q} + (A : Matrix (2^n) (2^m)) (B : Matrix (2^p) (2^q)) : + WF_Matrix A -> WF_Matrix B -> + @Mmult (2^n*2^p) (2^m*2^q) (2^q*2^m) + (A ⊗ B) (perm_to_matrix (q + m) (rotr (q + m) q)) = + @Mmult (2^n*2^p) (2^p*2^n) (2^q*2^m) + (perm_to_matrix (p + n) (rotr (p + n) p)) (B ⊗ A). +Proof. + unify_pows_two. + rewrite 2!perm_to_matrix_rotr_eq_kron_comm. + restore_dims. + pose proof (kron_comm_commutes_r (2^n) (2^m) + (2^p) (2^q) A B) as H. + rewrite !Nat.pow_add_r. + apply H. +Qed. + + +Lemma perm_to_matrix_swap_block_perm_natural {padl padm padr a} + (A : Matrix (2^a) (2^a)) : + @mat_equiv (2^padl*2^a*2^padm*2^a*2^padr) (2^padl*2^a*2^padm*2^a*2^padr) + (@Mmult _ (2^padl*2^a*2^padm*2^a*2^padr) _ + (I (2^padl) ⊗ A ⊗ I (2^padm * 2^a * 2^padr)) + (perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a))) + (@Mmult _ (2^padl*2^a*2^padm*2^a*2^padr) _ + (perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a)) + (I (2^padl * 2^a * 2^padm) ⊗ A ⊗ I (2^padr))). +Proof. + apply mat_equiv_of_all_basis_conj. + intros i j Hi Hj. + rewrite !Mmult_assoc. + rewrite <- !Nat.pow_add_r in *. + rewrite !basis_f_to_vec_alt by easy. + rewrite perm_to_matrix_permutes_qubits by cleanup_perm. + rewrite <- (transpose_involutive _ _ + (perm_to_matrix _ (swap_block_perm _ _ _))). + rewrite <- !Mmult_assoc, <- Mmult_transpose. + rewrite (perm_to_matrix_transpose_eq + (swap_block_perm_permutation padl padm padr a)). + rewrite (perm_to_matrix_eq_of_perm_eq _ _ _ + (swap_block_perm_inv padl padm padr a)). + rewrite perm_to_matrix_permutes_qubits by cleanup_perm. + replace (padl+a+padm+a+padr) with (padl+a+(padm+a+padr)) in * by lia. + rewrite 2!(f_to_vec_split'_eq (padl+a)), 2!(f_to_vec_split'_eq (padl)). + rewrite !(fun x y => kron_transpose' _ _ x y). + rewrite !(fun x y z => kron_mixed_product' _ _ _ _ _ _ _ x y z) by + (now rewrite ?Nat.pow_add_r; simpl;lia). + rewrite !Mmult_1_r by auto with wf_db. + symmetry. + + replace (padl+a+(padm+a+padr)) with ((padl+a+padm)+a+padr) in * by lia. + rewrite 2!(f_to_vec_split'_eq (padl+a+padm+a)), 2!(f_to_vec_split'_eq (_+_+_)). + rewrite !(fun x y => kron_transpose' _ _ x y). + rewrite !(fun x y z => kron_mixed_product' _ _ _ _ _ _ _ x y z) by + (now rewrite ?Nat.pow_add_r; simpl;lia). + rewrite !Mmult_1_r by auto with wf_db. + unfold kron. + rewrite !Nat.mod_1_r, Nat.Div0.div_0_l. + rewrite !basis_f_to_vec. + rewrite !basis_trans_basis. + rewrite !matrix_conj_basis_eq_lt + by show_moddy_lt. + rewrite !Cmult_if_1_l, !Cmult_if_if_1_r. + apply f_equal_if. + - do 4 simplify_bools_moddy_lia_one_kernel. + apply eq_iff_eq_true. + rewrite !andb_true_iff, !Nat.eqb_eq. + rewrite <- !funbool_to_nat_eq_iff. + split;intros [Hlow Hhigh]; + split. + + intros k Hk. + generalize (Hlow k ltac:(lia)). + unfold swap_block_perm. + now simplify_bools_lia. + + intros k Hk. + unfold swap_block_perm. + simplify_bools_lia. + bdestructΩ'. + * generalize (Hlow (padl+a+k) ltac:(lia)). + unfold swap_block_perm. + now simplify_bools_lia. + * generalize (Hlow (padl + a + k - (a + padm)) ltac:(lia)). + unfold swap_block_perm. + simplify_bools_lia. + intros <-. + f_equal; lia. + * apply_with_obligations + (Hhigh ((padl + a + k) - (padl + a + padm + a)) ltac:(lia)); + f_equal; [lia|]. + unfold swap_block_perm; bdestructΩ'. + + intros k Hk. + unfold swap_block_perm. + simplify_bools_lia. + bdestructΩ'. + * generalize (Hlow (k) ltac:(lia)). + unfold swap_block_perm. + now simplify_bools_lia. + * apply_with_obligations + (Hhigh ((a + padm) + k - (padl + a)) ltac:(lia)); + f_equal; [|lia]. + unfold swap_block_perm; bdestructΩ'. + * apply_with_obligations + (Hhigh (k - (padl + a)) ltac:(lia)); + f_equal; [|lia]. + unfold swap_block_perm; bdestructΩ'. + + intros k Hk. + apply_with_obligations (Hhigh (padm + a + k) ltac:(lia)); + f_equal; + unfold swap_block_perm; + bdestructΩ'. + - f_equal; + apply Bits.funbool_to_nat_eq; + intros; + unfold swap_block_perm; + bdestructΩ'; f_equal; lia. + - easy. +Qed. + +Lemma perm_to_matrix_swap_block_perm_natural_eq {padl padm padr a} + (A : Matrix (2^a) (2^a)) (HA : WF_Matrix A) : + @eq (Matrix (2^padl*2^a*2^padm*2^a*2^padr) (2^padl*2^a*2^padm*2^a*2^padr)) + (@Mmult _ (2^padl*2^a*2^padm*2^a*2^padr) _ + (I (2^padl) ⊗ A ⊗ I (2^padm * 2^a * 2^padr)) + (perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a))) + (@Mmult _ (2^padl*2^a*2^padm*2^a*2^padr) _ + (perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a)) + (I (2^padl * 2^a * 2^padm) ⊗ A ⊗ I (2^padr))). +Proof. + apply mat_equiv_eq; + auto using WF_Matrix_dim_change with wf_db. + apply perm_to_matrix_swap_block_perm_natural. +Qed. + +Lemma perm_to_matrix_swap_block_perm_natural_eq_alt {padl padm padr a} + (A : Matrix (2^a) (2^a)) (HA : WF_Matrix A) : + @eq (Matrix (2^padl*2^a*2^padm*2^a*2^padr) (2^(padl+a+padm+a+padr))) + (@Mmult _ (2^padl*2^a*2^padm*2^a*2^padr) _ + (I (2^padl) ⊗ A ⊗ I (2^padm * 2^a * 2^padr)) + (perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a))) + (@Mmult (2^padl*2^a*2^padm*2^a*2^padr) (2^padl*2^a*2^padm*2^a*2^padr) _ + (perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a)) + (I (2^padl * 2^a * 2^padm) ⊗ A ⊗ I (2^padr))). +Proof. + generalize (@perm_to_matrix_swap_block_perm_natural_eq + padl padm padr a A HA). + unify_pows_two. + easy. +Qed. + +Lemma perm_to_matrix_swap_block_perm_eq padl padm padr a : + perm_to_matrix (padl + a + padm + a + padr) + (swap_block_perm padl padm a) = + I (2^padl) ⊗ + (kron_comm (2^a) (2^padm * 2^a) × + (kron_comm (2^padm) (2^a) ⊗ I (2^a))) ⊗ + I (2^padr). +Proof. + rewrite (swap_block_perm_decomp_eq padl padr padm a). + rewrite <- !(Nat.add_assoc padl). + rewrite 2!perm_to_matrix_of_stack_perms by auto with perm_db. + rewrite perm_to_matrix_compose by auto with perm_db. + rewrite perm_to_matrix_of_stack_perms by auto with perm_db. + rewrite 3!perm_to_matrix_idn. + rewrite kron_assoc by auto with wf_db. + f_equal; [show_pow2_le..|]. + f_equal; [show_pow2_le..|]. + rewrite 2!perm_to_matrix_rotr_eq_kron_comm. + unify_pows_two. + rewrite (Nat.add_comm a padm). + easy. +Qed. + +#[export] Hint Rewrite perm_to_matrix_swap_block_perm_eq : perm_inv_db. + +Lemma perm_to_matrix_pullthrough_middle_eq_idn padl padm padr padm' f + (Hf : permutation (padl + padm + padr) f) + (Hfid : perm_eq_id_mid padl padm f) + (A : Matrix (2^padm') (2^padm)) (HA : WF_Matrix A) : + @Mmult (2^padl*2^padm'*2^padr) (2^padl*2^padm*2^padr) (2^(padl+padm+padr)) + (I (2^padl) ⊗ A ⊗ I (2^padr)) (perm_to_matrix (padl + padm + padr) f) = + @Mmult (2^(padl+padm'+padr)) (2^padl*2^padm'*2^padr) (2^padl*2^padm*2^padr) + (perm_to_matrix (padl + padm' + padr) + (expand_perm_id_mid padl padm' padr + (contract_perm_id_mid padl padm padr f))) + (I (2^padl) ⊗ A ⊗ I (2^padr)). +Proof. + rewrite (perm_to_matrix_eq_of_perm_eq _ _ _ + (perm_eq_sym (expand_contract_perm_perm_eq_idn_inv Hf Hfid))). + unfold expand_perm_id_mid. + rewrite 4!perm_to_matrix_compose by + (eapply perm_bounded_change_dims; auto with perm_db zarith + || apply compose_perm_bounded; + eapply perm_bounded_change_dims; auto with perm_db zarith). + (* replace (padl + padm + padr) with (padl + padr + padm) by lia. *) + rewrite !perm_to_matrix_of_stack_perms' by auto with perm_db zarith. + rewrite !perm_to_matrix_idn. + rewrite !perm_to_matrix_rotr_eq_kron_comm_alt, + !perm_to_matrix_rotr_eq_kron_comm. + unify_pows_two. + rewrite <- !Mmult_assoc. + rewrite !Nat.pow_add_r. + rewrite kron_assoc, <- 2!Nat.mul_assoc by auto with wf_db. + rewrite kron_mixed_product. + restore_dims. + rewrite (kron_comm_commutes_r _ _ _ _ A (I (2^padr))) by auto with wf_db. + rewrite (Nat.mul_comm (2^padm) (2^padr)). + rewrite <- kron_mixed_product. + rewrite <- kron_assoc by auto with wf_db. + rewrite !Mmult_assoc. + f_equal. + restore_dims. + rewrite !Nat.pow_add_r. + rewrite <- (Mmult_assoc (_ ⊗ _ ⊗ A)). + rewrite kron_mixed_product. + rewrite Mmult_1_r by auto. + rewrite id_kron. + restore_dims. + rewrite Mmult_1_l, <- (Mmult_1_r _ _ (perm_to_matrix _ _)) by auto with wf_db. + rewrite <- (Mmult_1_l _ _ A) by auto. + rewrite <- kron_mixed_product. + rewrite Mmult_1_r, Mmult_1_l by auto with wf_db. + rewrite Mmult_assoc. + f_equal. + rewrite kron_mixed_product, kron_comm_commutes_l by auto with wf_db. + rewrite <- kron_mixed_product. + restore_dims. + rewrite <- kron_assoc by auto with wf_db. + rewrite Nat.pow_add_r, <- id_kron. + now restore_dims. +Qed. + +Lemma perm_to_matrix_swap_perm_S a n (Ha : S a < n) : + perm_to_matrix n (swap_perm a (S a) n) = + I (2 ^ a) ⊗ swap ⊗ I (2 ^ (n - S (S a))). +Proof. + rewrite (perm_to_matrix_eq_of_perm_eq _ _ + (stack_perms (a + 2) (n - S (S a)) + (stack_perms a 2 idn (swap_perm 0 1 2)) idn)). + - rewrite 2!perm_to_matrix_of_stack_perms' by auto with perm_db zarith. + rewrite 2!perm_to_matrix_idn. + restore_dims. + do 2 f_equal. + apply mat_equiv_eq; auto with wf_db. + by_cell; reflexivity. + - cleanup_perm. + rewrite 2!swap_perm_defn by lia. + intros k Hk. + rewrite stack_perms_idn_f. + bdestructΩ'. +Qed. + + +Lemma enlarge_permutation_big_kron'_natural (ns ms : nat -> nat) + As (n : nat) (HAs : forall k, k < n -> WF_Matrix (As k)) + f (Hf : permutation n f) : + big_kron' ns ms As n × perm_to_matrix (big_sum ms n) + (enlarge_permutation n f ms) = + perm_to_matrix (big_sum ns n) + (enlarge_permutation n f ns) × + big_kron' (ns ∘ f) (ms ∘ f) + (fun i => As (f i)) n. +Proof. + symmetry. + assert (HAs' : forall k, k < n -> WF_Matrix (As (f k))) by + (intros k Hk; apply HAs; auto with perm_bounded_db). + rewrite Mmult_perm_to_matrix_l; + [ | | auto_wf | auto with perm_db]. + 2: { + apply WF_Matrix_dim_change; + [now rewrite <- Nsum_reorder by auto with perm_db..|]. + auto_wf. + } + apply equal_on_conj_basis_states_implies_equal. + - apply WF_Matrix_dim_change; + [now rewrite <- Nsum_reorder by auto with perm_db..|]. + auto_wf. + - auto_wf. + - intros g h. + symmetry. + rewrite !Mmult_assoc. + rewrite perm_to_matrix_permutes_qubits by auto with perm_db. + rewrite <- !Mmult_assoc. + rewrite perm_to_matrix_permutes_qubits_l by auto with perm_db. + rewrite 2!f_to_vec_big_split. + rewrite (Nsum_reorder n ns f Hf). + rewrite (Nsum_reorder n ms f Hf). + rewrite 2!f_to_vec_big_split. + rewrite 2!(big_kron'_transpose' _ (fun _ => 0) _ n _ 1). + restore_dims_using shelve. + rewrite big_kron'_Mmult by (intros; auto_wf). + rewrite big_kron'_Mmult by (intros; auto_wf). + rewrite big_kron'_Mmult by (unfold compose; intros; auto_wf). + rewrite big_kron'_Mmult by (unfold compose; intros; auto_wf). + rewrite (big_kron'_0_0_reorder _ n f (* (perm_inv' n f) *) + ltac:(auto with perm_db)) by (intros; auto_wf). + apply big_kron'_eq_bounded. + intros k Hk. + unfold compose at 1. + f_equal; [f_equal|]. + + f_equal. + apply f_to_vec_eq. + intros i Hi. + rewrite perm_inv_perm_inv. + * rewrite enlarge_permutation_add_big_sum_l + by auto with perm_bounded_db. + do 2 f_equal. + rewrite perm_inv'_eq by auto_perm. + rewrite perm_inv_is_linv_of_permutation by auto. + easy. + * rewrite <- Nsum_reorder by auto. + auto_perm. + * rewrite <- Nsum_reorder by auto. + rewrite (big_sum_split n (f k)) by auto_perm. + unfold compose in *. + cbn; lia. + + apply f_to_vec_eq. + intros i Hi. + unfold compose in Hi. + rewrite enlarge_permutation_add_big_sum_l by auto_perm. + do 2 f_equal. + rewrite perm_inv'_eq by auto with perm_db. + rewrite perm_inv_is_linv_of_permutation by auto. + easy. + Unshelve. + 1: now rewrite big_sum_0. + all: now rewrite <- Nsum_reorder by auto. +Qed. \ No newline at end of file diff --git a/Permutations.v b/Permutations.v index 88ce4fc..9dac898 100644 --- a/Permutations.v +++ b/Permutations.v @@ -1,1349 +1,3 @@ -Require Import Bits. -Require Import VectorStates. -Require Import Modulus. - -(** Facts about permutations and matrices that implement them. *) -Declare Scope perm_scope. -Local Open Scope perm_scope. -Local Open Scope nat_scope. - -Create HintDb perm_db. -Create HintDb perm_bounded_db. -Create HintDb perm_inv_db. -Create HintDb WF_perm_db. - -(** Permutations on (0, ..., n-1) *) -Definition permutation (n : nat) (f : nat -> nat) := - exists g, forall x, x < n -> (f x < n /\ g x < n /\ g (f x) = x /\ f (g x) = x). - -Lemma permutation_is_injective : forall n f, - permutation n f -> - forall x y, x < n -> y < n -> f x = f y -> x = y. -Proof. - intros n f [g Hbij] x y Hx Hy H. - destruct (Hbij x Hx) as [_ [_ [H0 _]]]. - destruct (Hbij y Hy) as [_ [_ [H1 _]]]. - rewrite <- H0. - rewrite <- H1. - rewrite H. - reflexivity. -Qed. - -Lemma permutation_is_surjective : forall n f, - permutation n f -> - forall k, k < n -> exists k', k' < n /\ f k' = k. -Proof. - intros n f Hf k Hk. - destruct Hf as [finv Hfinv]. - specialize (Hfinv k Hk). - exists (finv k). - intuition. -Qed. - -Lemma permutation_compose : forall n f g, - permutation n f -> - permutation n g -> - permutation n (f ∘ g)%prg. -Proof. - intros n f g [finv Hfbij] [ginv Hgbij]. - exists (ginv ∘ finv)%prg. - unfold compose. - intros x Hx. - destruct (Hgbij x) as [? [_ [? _]]]; auto. - destruct (Hfbij (g x)) as [? [_ [Hinv1 _]]]; auto. - destruct (Hfbij x) as [_ [? [_ ?]]]; auto. - destruct (Hgbij (finv x)) as [_ [? [_ Hinv2]]]; auto. - repeat split; auto. - rewrite Hinv1. - assumption. - rewrite Hinv2. - assumption. -Qed. - -(** The identity permutation *) -Notation idn := (fun (k : nat) => k). - -Lemma compose_idn_l : forall {T} (f : T -> nat), (idn ∘ f = f)%prg. -Proof. - intros. - unfold compose. - apply functional_extensionality; easy. -Qed. - -Lemma compose_idn_r : forall {T} (f : nat -> T), (f ∘ idn = f)%prg. -Proof. - intros. - unfold compose. - apply functional_extensionality; easy. -Qed. - -#[export] Hint Rewrite @compose_idn_r @compose_idn_l : perm_cleanup_db. - -Lemma idn_permutation : forall n, permutation n idn. -Proof. - intros. - exists idn. - easy. -Qed. - -Global Hint Resolve idn_permutation : perm_db. - -(** Notions of injectivity, boundedness, and surjectivity of f : nat -> nat - interpreted as a function from [n]_0 to [n]_0) and their equivalences *) -Notation perm_surj n f := (forall k, k < n -> exists k', k' < n /\ f k' = k). -Notation perm_bounded n f := (forall k, k < n -> f k < n). -Notation perm_inj n f := (forall k l, k < n -> l < n -> f k = f l -> k = l). - -Lemma fswap_injective_if_injective : forall {A} n (f:nat -> A) x y, - x < n -> y < n -> - perm_inj n f -> perm_inj n (fswap f x y). -Proof. - intros A n f x y Hx Hy Hinj k l Hk Hl. - unfold fswap. - bdestruct (k =? x); bdestruct (k =? y); - bdestruct (l =? x); bdestruct (l =? y); - subst; auto using Hinj. - all: intros Heq; - epose proof (Hinj _ _ _ _ Heq); - exfalso; lia. - Unshelve. - all: assumption. -Qed. - -Lemma fswap_injective_iff_injective : forall {A} n (f:nat -> A) x y, - x < n -> y < n -> - perm_inj n f <-> perm_inj n (fswap f x y). -Proof. - intros A n f x y Hx Hy. - split. - - apply fswap_injective_if_injective; easy. - - intros Hinj. - rewrite <- (fswap_involutive f x y). - apply fswap_injective_if_injective; easy. -Qed. - -Lemma fswap_surjective_if_surjective : forall n f x y, - x < n -> y < n -> - perm_surj n f -> perm_surj n (fswap f x y). -Proof. - intros n f x y Hx Hy Hsurj k Hk. - destruct (Hsurj k Hk) as [k' [Hk' Hfk']]. - bdestruct (k' =? x); [|bdestruct (k' =? y)]. - - exists y. - split; [assumption|]. - subst. - rewrite fswap_simpl2. - easy. - - exists x. - split; [assumption|]. - subst. - rewrite fswap_simpl1. - easy. - - exists k'. - split; [assumption|]. - rewrite fswap_neq; lia. -Qed. - -Lemma fswap_surjective_iff_surjective : forall n f x y, - x < n -> y < n -> - perm_surj n f <-> perm_surj n (fswap f x y). -Proof. - intros n f x y Hx Hy. - split. - - apply fswap_surjective_if_surjective; easy. - - intros Hsurj. - rewrite <- (fswap_involutive f x y). - apply fswap_surjective_if_surjective; easy. -Qed. - -Lemma fswap_bounded_if_bounded : forall n f x y, - x < n -> y < n -> - perm_bounded n f -> perm_bounded n (fswap f x y). -Proof. - intros n f x y Hx Hy Hbounded k Hk. - unfold fswap. - bdestruct_all; - apply Hbounded; - easy. -Qed. - -Lemma fswap_bounded_iff_bounded : forall n f x y, - x < n -> y < n -> - perm_bounded n f <-> perm_bounded n (fswap f x y). -Proof. - intros n f x y Hx Hy. - split. - - apply fswap_bounded_if_bounded; easy. - - intros Hbounded. - rewrite <- (fswap_involutive f x y). - apply fswap_bounded_if_bounded; easy. -Qed. - -Lemma surjective_of_eq_boundary_shrink : forall n f, - perm_surj (S n) f -> f n = n -> perm_surj n f. -Proof. - intros n f Hsurj Hfn k Hk. - assert (HkS : k < S n) by lia. - destruct (Hsurj k HkS) as [k' [Hk' Hfk']]. - bdestruct (k' =? n). - - exfalso; subst; lia. - - exists k'. - split; [lia | assumption]. -Qed. - -Lemma surjective_of_eq_boundary_grow : forall n f, - perm_surj n f -> f n = n -> perm_surj (S n) f. -Proof. - intros n f Hsurj Hfn k Hk. - bdestruct (k =? n). - - exists n; lia. - - assert (H'k : k < n) by lia. - destruct (Hsurj k H'k) as [k' [Hk' Hfk']]. - exists k'; lia. -Qed. - -Lemma fswap_at_boundary_surjective : forall n f n', - n' < S n -> perm_surj (S n) f -> f n' = n -> - perm_surj n (fswap f n' n). -Proof. - intros n f n' Hn' Hsurj Hfn' k Hk. - bdestruct (k =? f n). - - exists n'. - split. - + assert (Hneq: n' <> n); [|lia]. - intros Hfalse. - rewrite Hfalse in Hfn'. - rewrite Hfn' in H. - lia. - + rewrite fswap_simpl1; easy. - - assert (H'k : k < S n) by lia. - destruct (Hsurj k H'k) as [k' [Hk' Hfk']]. - assert (Hk'n: k' <> n) by (intros Hfalse; subst; lia). - assert (Hk'n': k' <> n') by (intros Hfalse; subst; lia). - exists k'. - split; [lia|]. - rewrite fswap_neq; lia. -Qed. - -Lemma injective_monotone : forall {A} n (f : nat -> A) m, - m < n -> perm_inj n f -> perm_inj m f. -Proof. - intros A n f m Hmn Hinj k l Hk Hl Hfkl. - apply Hinj; auto; lia. -Qed. - -Lemma injective_and_bounded_grow_of_boundary : forall n f, - perm_inj n f /\ perm_bounded n f -> f n = n -> - perm_inj (S n) f /\ perm_bounded (S n) f. -Proof. - intros n f [Hinj Hbounded] Hfn. - split. - - intros k l Hk Hl Hfkl. - bdestruct (k =? n). - + subst. - bdestruct (l =? n); [easy|]. - assert (H'l : l < n) by lia. - specialize (Hbounded _ H'l). - lia. - + assert (H'k : k < n) by lia. - bdestruct (l =? n). - * specialize (Hbounded _ H'k). - subst. lia. - * assert (H'l : l < n) by lia. - apply Hinj; easy. - - intros k Hk. - bdestruct (k perm_inj n f /\ perm_bounded n f. -Proof. - intros n. - induction n; [easy|]. - intros f Hsurj. - assert (HnS : n < S n) by lia. - destruct (Hsurj n HnS) as [n' [Hn' Hfn']]. - pose proof (fswap_at_boundary_surjective _ _ _ Hn' Hsurj Hfn') as Hswap_surj. - specialize (IHn (fswap f n' n) Hswap_surj). - rewrite (fswap_injective_iff_injective _ f n' n); [|easy|easy]. - rewrite (fswap_bounded_iff_bounded _ f n' n); [|easy|easy]. - apply injective_and_bounded_grow_of_boundary; - [| rewrite fswap_simpl2; easy]. - easy. -Qed. - -Lemma injective_and_bounded_shrink_of_boundary : forall n f, - perm_inj (S n) f /\ perm_bounded (S n) f -> f n = n -> - perm_inj n f /\ perm_bounded n f. -Proof. - intros n f [Hinj Hbounded] Hfn. - split. - - eapply injective_monotone, Hinj; lia. - - intros k Hk. - assert (H'k : k < S n) by lia. - specialize (Hbounded k H'k). - bdestruct (f k =? n). - + rewrite <- Hfn in H. - assert (HnS : n < S n) by lia. - specialize (Hinj _ _ H'k HnS H). - lia. - + lia. -Qed. - -(* Formalization of proof sketch of pigeonhole principle - from https://math.stackexchange.com/a/910790 *) -Lemma exists_bounded_decidable : forall n P, - (forall k, k < n -> {P k} + {~ P k}) -> - {exists j, j < n /\ P j} + {~ exists j, j < n /\ P j}. -Proof. - intros n P HPdec. - induction n. - - right; intros [x [Hlt0 _]]; inversion Hlt0. - - destruct (HPdec n) as [HPn | HnPn]; [lia| |]. - + left. exists n; split; [lia | assumption]. - + destruct IHn as [Hex | Hnex]. - * intros k Hk; apply HPdec; lia. - * left. - destruct Hex as [j [Hjn HPj]]. - exists j; split; [lia | assumption]. - * right. - intros [j [Hjn HPj]]. - apply Hnex. - bdestruct (j =? n). - -- exfalso; apply HnPn; subst; easy. - -- exists j; split; [lia | easy]. -Qed. - -Lemma has_preimage_decidable : forall n f, - forall k, k < n -> - {exists j, j < n /\ f j = k} + {~exists j, j < n /\ f j = k}. -Proof. - intros n f k Hk. - apply exists_bounded_decidable. - intros k' Hk'. - bdestruct (f k' =? k). - - left; easy. - - right; easy. -Qed. - -Lemma pigeonhole_S : forall n f, - (forall i, i < S n -> f i < n) -> - exists i j, i < S n /\ j < i /\ f i = f j. -Proof. - intros n. - destruct n; - [intros f Hbounded; specialize (Hbounded 0); lia|]. - induction n; intros f Hbounded. - 1: { - exists 1, 0. - pose (Hbounded 0). - pose (Hbounded 1). - lia. - } - destruct (has_preimage_decidable (S (S n)) f (f (S (S n)))) as [Hex | Hnex]. - - apply Hbounded; lia. - - destruct Hex as [j [Hj Hfj]]. - exists (S (S n)), j. - repeat split; lia. - - destruct (IHn (fun k => if f k f k >= f (S (S n)) -> f k > f (S (S n))). 1:{ - intros k Hk Hge. - bdestruct (f k =? f (S (S n))). - - exfalso; apply Hnex; exists k; split; lia. - - lia. - } - bdestruct (f i exists k, k < S n /\ f k = n. -Proof. - intros n f [Hinj Hbounded]. - destruct (has_preimage_decidable (S n) f n) as [Hex | Hnex]; - [lia | assumption |]. - (* Now, contradict injectivity using pigeonhole principle *) - exfalso. - assert (Hbounded': forall j, j < S n -> f j < n). 1:{ - intros j Hj. - specialize (Hbounded j Hj). - bdestruct (f j =? n). - - exfalso; apply Hnex; exists j; easy. - - lia. - } - destruct (pigeonhole_S n f Hbounded') as [i [j [Hi [Hj Heq]]]]. - absurd (i = j). - - lia. - - apply Hinj; lia. -Qed. - -Lemma surjective_of_injective_and_bounded : forall n f, - perm_inj n f /\ perm_bounded n f -> perm_surj n f. -Proof. - induction n; [easy|]. - intros f Hinj_bounded. - destruct (n_has_preimage_of_injective_and_bounded n f Hinj_bounded) as [n' [Hn' Hfn']]. - rewrite (fswap_injective_iff_injective _ _ n n') in Hinj_bounded; - [|lia|lia]. - rewrite (fswap_bounded_iff_bounded _ _ n n') in Hinj_bounded; - [|lia|lia]. - rewrite (fswap_surjective_iff_surjective _ _ n n'); - [|lia|easy]. - intros k Hk. - bdestruct (k =? n). - - exists n. - split; [lia|]. - rewrite fswap_simpl1; subst; easy. - - pose proof (injective_and_bounded_shrink_of_boundary n _ Hinj_bounded) as Hinj_bounded'. - rewrite fswap_simpl1 in Hinj_bounded'. - specialize (Hinj_bounded' Hfn'). - destruct (IHn (fswap f n n') Hinj_bounded' k) as [k' [Hk' Hfk']]; [lia|]. - exists k'. - split; [lia|assumption]. -Qed. - -(** Explicit inverse of a permutation *) -Fixpoint perm_inv n f k : nat := - match n with - | 0 => 0%nat - | S n' => if f n' =? k then n'%nat else perm_inv n' f k - end. - -Lemma perm_inv_bounded_S : forall n f k, - perm_inv (S n) f k < S n. -Proof. - intros n f k. - induction n; simpl. - - bdestructΩ (f 0 =? k). - - bdestruct (f (S n) =? k); [|transitivity (S n); [apply IHn|]]. - all: apply Nat.lt_succ_diag_r. -Qed. - -Lemma perm_inv_bounded : forall n f, - perm_bounded n (perm_inv n f). -Proof. - induction n. - - easy. - - intros. - apply perm_inv_bounded_S. -Qed. - -#[export] Hint Resolve perm_inv_bounded_S perm_inv_bounded : perm_bounded_db. - -Lemma perm_inv_is_linv_of_injective : forall n f, - perm_inj n f -> - forall k, k < n -> perm_inv n f (f k) = k. -Proof. - intros n f Hinj k Hk. - induction n. - - easy. - - simpl. - bdestruct (f n =? f k). - + apply Hinj; lia. - + assert (k <> n) by (intros Heq; subst; easy). - apply IHn; [auto|]. - assert (k <> n) by (intros Heq; subst; easy). - lia. -Qed. - -Lemma perm_inv_is_rinv_of_surjective' : forall n f k, - (exists l, l < n /\ f l = k) -> - f (perm_inv n f k) = k. -Proof. - intros n f k. - induction n. - - intros []; easy. - - intros [l [Hl Hfl]]. - simpl. - bdestruct (f n =? k); [easy|]. - apply IHn. - exists l. - split; [|easy]. - bdestruct (l =? n); [subst; easy|]. - lia. -Qed. - -Lemma perm_inv_is_rinv_of_surjective : forall n f, - perm_surj n f -> forall k, k < n -> - f (perm_inv n f k) = k. -Proof. - intros n f Hsurj k Hk. - apply perm_inv_is_rinv_of_surjective', Hsurj, Hk. -Qed. - -Lemma perm_inv_is_linv_of_permutation : forall n f, - permutation n f -> - forall k, k < n -> perm_inv n f (f k) = k. -Proof. - intros n f Hperm. - apply perm_inv_is_linv_of_injective, permutation_is_injective, Hperm. -Qed. - -Lemma perm_inv_is_rinv_of_permutation : forall n f, - permutation n f -> - forall k, k < n -> f (perm_inv n f k) = k. -Proof. - intros n f Hperm k Hk. - apply perm_inv_is_rinv_of_surjective', (permutation_is_surjective _ _ Hperm _ Hk). -Qed. - -Lemma perm_inv_is_inv_of_surjective_injective_bounded : forall n f, - perm_surj n f -> perm_inj n f -> perm_bounded n f -> - (forall k, k < n -> - f k < n /\ perm_inv n f k < n /\ perm_inv n f (f k) = k /\ f (perm_inv n f k) = k). -Proof. - intros n f Hsurj Hinj Hbounded. - intros k Hk; repeat split. - - apply Hbounded, Hk. - - apply perm_inv_bounded, Hk. - - rewrite perm_inv_is_linv_of_injective; easy. - - rewrite perm_inv_is_rinv_of_surjective'; [easy|]. - apply Hsurj; easy. -Qed. - -Lemma permutation_iff_surjective : forall n f, - permutation n f <-> perm_surj n f. -Proof. - split. - - apply permutation_is_surjective. - - intros Hsurj. - exists (perm_inv n f). - pose proof (injective_and_bounded_of_surjective n f Hsurj). - apply perm_inv_is_inv_of_surjective_injective_bounded; easy. -Qed. - -Lemma perm_inv_permutation n f : permutation n f -> - permutation n (perm_inv n f). -Proof. - intros Hperm. - exists f. - intros k Hk; repeat split. - - apply perm_inv_bounded, Hk. - - destruct Hperm as [? H]; apply H, Hk. - - rewrite perm_inv_is_rinv_of_permutation; easy. - - rewrite perm_inv_is_linv_of_permutation; easy. -Qed. - -#[export] Hint Resolve perm_inv_permutation : perm_db. - -Lemma permutation_is_bounded n f : permutation n f -> - perm_bounded n f. -Proof. - intros [finv Hfinv] k Hk. - destruct (Hfinv k Hk); easy. -Qed. - -Lemma id_permutation : forall n, - permutation n Datatypes.id. -Proof. - intros. - exists Datatypes.id. - intros. - unfold Datatypes.id. - easy. -Qed. - -Lemma fswap_permutation : forall n f x y, - permutation n f -> - (x < n)%nat -> - (y < n)%nat -> - permutation n (fswap f x y). -Proof. - intros. - replace (fswap f x y) with (f ∘ (fswap (fun i => i) x y))%prg. - apply permutation_compose; auto. - exists (fswap (fun i => i) x y). - intros. unfold fswap. - bdestruct_all; subst; auto. - apply functional_extensionality; intros. - unfold compose, fswap. - bdestruct_all; easy. -Qed. - -Lemma fswap_at_boundary_permutation : forall n f x, - permutation (S n) f -> - (x < S n)%nat -> f x = n -> - permutation n (fswap f x n). -Proof. - intros n f x. - rewrite 2!permutation_iff_surjective. - intros HsurjSn Hx Hfx. - apply fswap_at_boundary_surjective; easy. -Qed. - -(** Well-foundedness of permutations; f k = k for k not in [n]_0 *) -Definition WF_Perm (n : nat) (f : nat -> nat) := - forall k, n <= k -> f k = k. - -Lemma monotonic_WF_Perm n m f : WF_Perm n f -> n <= m -> - WF_Perm m f. -Proof. - intros HWF Hnm k Hk. - apply HWF; lia. -Qed. - -#[export] Hint Resolve monotonic_WF_Perm : WF_perm_db. - -Lemma compose_WF_Perm n f g : WF_Perm n f -> WF_Perm n g -> - WF_Perm n (f ∘ g)%prg. -Proof. - unfold compose. - intros Hf Hg k Hk. - rewrite Hg, Hf; easy. -Qed. - -#[export] Hint Resolve compose_WF_Perm : WF_perm_db. - -Lemma linv_WF_of_WF {n} {f finv} - (HfWF : WF_Perm n f) (Hinv : (finv ∘ f = idn)%prg) : - WF_Perm n finv. -Proof. - intros k Hk. - rewrite <- (HfWF k Hk). - unfold compose in Hinv. - apply (f_equal_inv k) in Hinv. - rewrite Hinv, (HfWF k Hk). - easy. -Qed. - -Lemma bounded_of_WF_linv {n} {f finv} - (HWF: WF_Perm n f) (Hinv : (finv ∘ f = idn)%prg) : - perm_bounded n f. -Proof. - intros k Hk. - pose proof (linv_WF_of_WF HWF Hinv) as HWFinv. - unfold compose in Hinv. - apply (f_equal_inv k) in Hinv. - bdestruct (f k nat) (n:nat) {finv finv'} - (Hf: permutation n f) (HfWF : WF_Perm n f) - (Hfinv : (finv ∘ f = idn)%prg) (Hfinv' : (finv' ∘ f = idn)%prg) : - finv = finv'. -Proof. - apply functional_extensionality; intros k. - pose proof (linv_WF_of_WF HfWF Hfinv) as HfinvWF. - pose proof (linv_WF_of_WF HfWF Hfinv') as Hfinv'WF. - bdestruct (n <=? k). - - rewrite HfinvWF, Hfinv'WF; easy. - - destruct Hf as [fi Hfi]. - specialize (Hfi k H). - unfold compose in Hfinv, Hfinv'. - apply (f_equal_inv (fi k)) in Hfinv, Hfinv'. - replace (f (fi k)) with k in * by easy. - rewrite Hfinv, Hfinv'. - easy. -Qed. - -Lemma permutation_monotonic_of_WF f m n : (m <= n)%nat -> - permutation m f -> WF_Perm m f -> - permutation n f. -Proof. - intros Hmn [finv_m Hfinv_m] HWF. - exists (fun k => if m <=? k then k else finv_m k). - intros k Hk. - bdestruct (m <=? k). - - rewrite HWF; bdestruct_all; auto. - - specialize (Hfinv_m _ H). - repeat split; bdestruct_all; try easy; lia. -Qed. - - -Notation perm_eq n f g := (forall k, k < n -> f k = g k). - -Lemma eq_of_WF_perm_eq n f g : WF_Perm n f -> WF_Perm n g -> - perm_eq n f g -> f = g. -Proof. - intros HfWF HgWF Heq. - apply functional_extensionality; intros k. - bdestruct (k perm_bounded n finv -> - perm_eq n (f ∘ finv)%prg idn <-> perm_eq n (finv ∘ f)%prg idn. -Proof. - intros Hperm Hbounded. - split; unfold compose. - - intros Hrinv. - intros k Hk. - apply (permutation_is_injective n f Hperm); try easy. - + apply Hbounded, permutation_is_bounded, Hk. - apply Hperm. - + rewrite Hrinv; [easy|]. - apply (permutation_is_bounded n f Hperm _ Hk). - - intros Hlinv k Hk. - destruct Hperm as [fi Hf]. - destruct (Hf k Hk) as [Hfk [Hfik [Hfifk Hffik]]]. - rewrite <- Hffik. - rewrite Hlinv; easy. -Qed. - -Notation is_perm_rinv n f finv := (perm_eq n (f ∘ finv)%prg idn). -Notation is_perm_linv n f finv := (perm_eq n (finv ∘ f)%prg idn). -Notation is_perm_inv n f finv := - (perm_eq n (f ∘ finv)%prg idn /\ perm_eq n (finv ∘ f)%prg idn). - -Lemma perm_linv_injective_of_surjective n f finv finv' : - perm_surj n f -> is_perm_linv n f finv -> is_perm_linv n f finv' -> - perm_eq n finv finv'. -Proof. - intros Hsurj Hfinv Hfinv' k Hk. - destruct (Hsurj k Hk) as [k' [Hk' Hfk']]. - rewrite <- Hfk'. - unfold compose in *. - rewrite Hfinv, Hfinv'; easy. -Qed. - -Lemma perm_bounded_rinv_injective_of_injective n f finv finv' : - perm_inj n f -> perm_bounded n finv -> perm_bounded n finv' -> - is_perm_rinv n f finv -> is_perm_rinv n f finv' -> - perm_eq n finv finv'. -Proof. - intros Hinj Hbounded Hbounded' Hfinv Hfinv' k Hk. - apply Hinj; auto. - unfold compose in *. - rewrite Hfinv, Hfinv'; easy. -Qed. - -Lemma permutation_inverse_injective n f finv finv' : permutation n f -> - is_perm_inv n f finv -> is_perm_inv n f finv' -> - perm_eq n finv finv'. -Proof. - intros Hperm Hfinv Hfinv'. - eapply perm_linv_injective_of_surjective. - + apply permutation_is_surjective, Hperm. - + destruct (Hfinv); auto. - + destruct (Hfinv'); auto. -Qed. - -Fixpoint for_all_nat_lt (f : nat -> bool) (k : nat) := - match k with - | 0 => true - | S k' => f k' && for_all_nat_lt f k' - end. - -Lemma forall_nat_lt_S (P : forall k : nat, Prop) (n : nat) : - (forall k, k < S n -> P k) <-> P n /\ (forall k, k < n -> P k). -Proof. - split. - - intros Hall. - split; intros; apply Hall; lia. - - intros [Hn Hall]. - intros k Hk. - bdestruct (k=?n); [subst; easy | apply Hall; lia]. -Qed. - -Lemma for_all_nat_ltE {f : nat -> bool} {P : forall k : nat, Prop} - (ref : forall k, reflect (P k) (f k)) : - forall n, (forall k, k < n -> P k) <-> (for_all_nat_lt f n = true). -Proof. - induction n. - - easy. - - rewrite forall_nat_lt_S. - simpl. - rewrite andb_true_iff. - rewrite IHn. - apply and_iff_compat_r. - apply reflect_iff; easy. -Qed. - -Definition perm_inv_is_inv_pred (f : nat -> nat) (n : nat) : Prop := - forall k, k < n -> - f k < n /\ perm_inv n f k < n /\ - perm_inv n f (f k) = k /\ f (perm_inv n f k) = k. - -Definition is_permutation (f : nat -> nat) (n : nat) := - for_all_nat_lt - (fun k => - (f k nat) (n : nat) : - permutation n f <-> perm_inv_is_inv_pred f n. -Proof. - split. - - intros Hperm. - intros k Hk. - repeat split. - + destruct Hperm as [g Hg]; - apply (Hg k Hk). - + apply perm_inv_bounded; easy. - + apply perm_inv_is_linv_of_permutation; easy. - + apply perm_inv_is_rinv_of_permutation; easy. - - intros Hperminv. - exists (perm_inv n f); easy. -Qed. - -Lemma is_permutationE (f : nat -> nat) (n : nat) : - perm_inv_is_inv_pred f n <-> is_permutation f n = true. -Proof. - unfold perm_inv_is_inv_pred, is_permutation. - apply for_all_nat_ltE. - intros k. - apply iff_reflect. - rewrite 3!andb_true_iff. - rewrite 2!Nat.ltb_lt, 2!Nat.eqb_eq, 2!and_assoc. - easy. -Qed. - -Lemma permutation_iff_is_permutation (f : nat -> nat) (n : nat) : - permutation n f <-> is_permutation f n = true. -Proof. - rewrite permutation_iff_perm_inv_is_inv. - apply is_permutationE. -Qed. - -Lemma permutationP (f : nat -> nat) (n : nat) : - reflect (permutation n f) (is_permutation f n). -Proof. - apply iff_reflect, permutation_iff_is_permutation. -Qed. - -Definition permutation_dec (f : nat -> nat) (n : nat) : - {permutation n f} + {~ permutation n f} := - reflect_dec _ _ (permutationP f n). - - -(** vsum terms can be arbitrarily reordered *) -Lemma vsum_reorder : forall {d} n (v : nat -> Vector d) f, - permutation n f -> - big_sum v n = big_sum (fun i => v (f i)) n. -Proof. - intros. - generalize dependent f. - induction n. - reflexivity. - intros f [g Hg]. - destruct (Hg n) as [_ [H1 [_ H2]]]; try lia. - rewrite (vsum_eq_up_to_fswap _ f _ (g n) n) by auto. - repeat rewrite <- big_sum_extend_r. - rewrite fswap_simpl2. - rewrite H2. - specialize (IHn (fswap f (g n) n)). - rewrite <- IHn. - reflexivity. - apply fswap_at_boundary_permutation; auto. - exists g. auto. -Qed. - -(** showing every permutation is a sequence of fswaps *) - -(* note the list acts on the left, for example, [s1,s2,...,sk] ⋅ f = s1 ⋅ ( ... ⋅ (sk ⋅ f)) *) -Fixpoint stack_fswaps (f : nat -> nat) (l : list (nat * nat)) := - match l with - | [] => f - | p :: ps => (fswap (Datatypes.id) (fst p) (snd p) ∘ (stack_fswaps f ps))%prg - end. - -Definition WF_fswap_stack n (l : list (nat * nat)) := - forall p, In p l -> (fst p < n /\ snd p < n). - -Lemma WF_fswap_stack_pop : forall n a l, - WF_fswap_stack n (a :: l) -> WF_fswap_stack n l. -Proof. intros. - unfold WF_fswap_stack in *. - intros. - apply H. - right; easy. -Qed. - -Lemma WF_fswap_stack_cons : forall n a l, - fst a < n -> snd a < n -> WF_fswap_stack n l -> WF_fswap_stack n (a :: l). -Proof. intros. - unfold WF_fswap_stack in *. - intros. - destruct H2; subst; auto. -Qed. - -Lemma WF_fswap_miss : forall n l i, - WF_fswap_stack n l -> - n <= i -> - (stack_fswaps Datatypes.id l) i = i. -Proof. induction l. - intros; simpl; easy. - intros; simpl. - unfold compose. - rewrite IHl; auto. - unfold fswap, Datatypes.id; simpl. - destruct (H a). - left; auto. - bdestruct_all; try lia. - apply WF_fswap_stack_pop in H; auto. -Qed. - -Lemma stack_fswaps_permutation : forall {n} (f : nat -> nat) (l : list (nat * nat)), - WF_fswap_stack n l -> - permutation n f -> - permutation n (stack_fswaps f l). -Proof. induction l. - - intros. easy. - - intros. - simpl. - apply permutation_compose. - apply fswap_permutation. - apply id_permutation. - 3 : apply IHl; auto. - 3 : apply WF_fswap_stack_pop in H; auto. - all : apply H; left; easy. -Qed. - -Lemma stack_fswaps_cons : forall (p : nat * nat) (l : list (nat * nat)), - ((stack_fswaps Datatypes.id [p]) ∘ (stack_fswaps Datatypes.id l))%prg = - stack_fswaps Datatypes.id (p :: l). -Proof. intros. - simpl. - rewrite compose_id_right. - easy. -Qed. - -(* -Theorem all_perms_are_fswap_stacks : forall {n} f, - permutation n f -> - exists l, WF_fswap_stack n l /\ f = (stack_fswaps Datatypes.id l) /\ length l = n. -Proof. induction n. - - intros. - exists []; simpl. -*) - -Definition ordered_real_function n (f : nat -> R) := - forall i j, i < n -> j < n -> i <= j -> (f j <= f i)%R. - -Lemma get_real_function_min : forall {n} (f : nat -> R), - exists n0, (n0 < (S n))%nat /\ (forall i, (i < (S n))%nat -> (f n0 <= f i)%R). -Proof. induction n. - - intros. - exists O; intros. - split; auto. - intros. - destruct i; try lia. - lra. - - intros. - destruct (IHn f) as [n0 [H H0] ]. - destruct (Rlt_le_dec (f n0) (f (S n))). - + exists n0; intros. - split; try lia. - intros. - bdestruct (i =? (S n))%nat; subst. - lra. - apply H0. - bdestruct (n0 R), - exists l, WF_fswap_stack n l /\ - ordered_real_function n (f ∘ (stack_fswaps Datatypes.id l))%prg. -Proof. intros. - generalize dependent f. - induction n. - - intros; exists []. - split; auto. - unfold WF_fswap_stack; intros. - destruct H. - simpl. - unfold ordered_real_function; intros; lia. - - intros. - destruct (@get_real_function_min n f) as [n0 [H H0]]. - destruct (IHn (f ∘ (stack_fswaps Datatypes.id [(n0, n)]))%prg) as [l [H1 H2]]. - exists ((n0, n) :: l). - split. - apply WF_fswap_stack_cons; simpl; auto. - unfold WF_fswap_stack in *; intros. - apply H1 in H3. - lia. - rewrite compose_assoc, stack_fswaps_cons in H2. - unfold ordered_real_function in *. - intros. - bdestruct (j =? n); subst. - simpl. - rewrite <- compose_assoc. - assert (H' : permutation (S n) - (fswap Datatypes.id n0 n ∘ stack_fswaps Datatypes.id l)%prg). - { apply permutation_compose. - apply fswap_permutation; auto. - apply id_permutation. - apply stack_fswaps_permutation. - unfold WF_fswap_stack in *; intros. - apply H1 in H6. - lia. - apply id_permutation. } - unfold compose in *. - destruct H' as [g H6]. - destruct (H6 i); auto. - rewrite (WF_fswap_miss n); auto. - replace (fswap Datatypes.id n0 n n) with n0. - apply H0; easy. - unfold fswap, Datatypes.id. - bdestruct_all; simpl; easy. - bdestruct (j nat) : Square n := - (fun x y => if (x =? p y) && (x WF_Unitary (perm_mat n p). -Proof. - intros n p [pinv Hp]. - split. - apply perm_mat_WF. - unfold Mmult, adjoint, perm_mat, I. - prep_matrix_equality. - destruct ((x =? y) && (x - perm_mat n f × perm_mat n g = perm_mat n (f ∘ g)%prg. -Proof. - intros n f g [ginv Hgbij]. - unfold perm_mat, Mmult, compose. - prep_matrix_equality. - destruct ((x =? f (g y)) && (x f x = x) -> - perm_mat n f = I n. -Proof. - intros n f Hinv. - unfold perm_mat, I. - prep_matrix_equality. - bdestruct_all; simpl; try lca. - rewrite Hinv in H1 by assumption. - contradiction. - rewrite Hinv in H1 by assumption. - contradiction. -Qed. - -Lemma perm_mat_col_swap_I : forall n f i j, - (forall x, x < n -> f x = x) -> - i < n -> j < n -> - perm_mat n (fswap f i j) = col_swap (I n) i j. -Proof. intros. - unfold perm_mat, fswap, col_swap, I. - prep_matrix_equality. - rewrite 2 H; auto. - bdestruct_all; simpl; try lia; auto. - rewrite H in H4; auto; lia. - rewrite H in H4; auto; lia. -Qed. - -Lemma perm_mat_col_swap : forall n f i j, - i < n -> j < n -> - perm_mat n (fswap f i j) = col_swap (perm_mat n f) i j. -Proof. intros. - unfold perm_mat, fswap, col_swap, I. - prep_matrix_equality. - bdestruct_all; simpl; try lia; auto. -Qed. - -Lemma perm_mat_row_swap : forall n f i j, - i < n -> j < n -> - perm_mat n (fswap f i j) = (row_swap (perm_mat n f)† i j)†. -Proof. intros. - unfold perm_mat, fswap, row_swap, I, adjoint. - prep_matrix_equality. - bdestruct_all; simpl; try lia; auto; lca. -Qed. - -Lemma perm_mat_e_i : forall n f i, - i < n -> - permutation n f -> - (perm_mat n f) × e_i i = e_i (f i). -Proof. intros. - apply mat_equiv_eq; auto with wf_db. - unfold mat_equiv; intros. - destruct j; try lia. - unfold Mmult. - apply big_sum_unique. - exists i. - split; auto. - split. - unfold e_i, perm_mat. - bdestruct_all; simpl; lca. - intros. - unfold e_i. - bdestruct_all; simpl; lca. -Qed. - -(* with get_entry_with_e_i this became soo much easier *) -Lemma perm_mat_conjugate : forall {n} (A : Square n) f (i j : nat), - WF_Matrix A -> - i < n -> j < n -> - permutation n f -> - ((perm_mat n f)† × A × ((perm_mat n f))) i j = A (f i) (f j). -Proof. intros. - rewrite get_entry_with_e_i, (get_entry_with_e_i A); auto. - rewrite <- 2 Mmult_assoc, <- Mmult_adjoint. - rewrite perm_mat_e_i; auto. - rewrite 3 Mmult_assoc. - rewrite perm_mat_e_i; auto. - all : destruct H2; apply H2; auto. -Qed. - -Lemma perm_mat_conjugate_nonsquare : forall {m n} (A : Matrix m n) f (i j : nat), - WF_Matrix A -> - i < m -> j < n -> - permutation m f -> permutation n f -> - ((perm_mat m f)† × A × ((perm_mat n f))) i j = A (f i) (f j). -Proof. intros. - rewrite get_entry_with_e_i, (get_entry_with_e_i A); auto. - rewrite <- 2 Mmult_assoc, <- Mmult_adjoint. - rewrite perm_mat_e_i; auto. - rewrite 3 Mmult_assoc. - rewrite perm_mat_e_i; auto. - all : destruct H2; destruct H3; try apply H2; try apply H3; auto. -Qed. - -(** Given a permutation p over n qubits, construct a permutation over 2^n indices. *) -Definition qubit_perm_to_nat_perm n (p : nat -> nat) := - fun x:nat => funbool_to_nat n ((nat_to_funbool n x) ∘ p)%prg. - -Lemma qubit_perm_to_nat_perm_bij : forall n p, - permutation n p -> permutation (2^n) (qubit_perm_to_nat_perm n p). -Proof. - intros n p [pinv Hp]. - unfold qubit_perm_to_nat_perm. - exists (fun x => funbool_to_nat n ((nat_to_funbool n x) ∘ pinv)%prg). - intros x Hx. - repeat split. - apply funbool_to_nat_bound. - apply funbool_to_nat_bound. - unfold compose. - erewrite funbool_to_nat_eq. - 2: { intros y Hy. - rewrite funbool_to_nat_inverse. - destruct (Hp y) as [_ [_ [_ H]]]. - assumption. - rewrite H. - reflexivity. - destruct (Hp y) as [_ [? _]]; auto. } - rewrite nat_to_funbool_inverse; auto. - unfold compose. - erewrite funbool_to_nat_eq. - 2: { intros y Hy. - rewrite funbool_to_nat_inverse. - destruct (Hp y) as [_ [_ [H _]]]. - assumption. - rewrite H. - reflexivity. - destruct (Hp y) as [? _]; auto. } - rewrite nat_to_funbool_inverse; auto. -Qed. - -(** Transform a (0,...,n-1) permutation into a 2^n by 2^n matrix. *) -Definition perm_to_matrix n p := - perm_mat (2 ^ n) (qubit_perm_to_nat_perm n p). - -Lemma perm_to_matrix_permutes_qubits : forall n p f, - permutation n p -> - perm_to_matrix n p × f_to_vec n f = f_to_vec n (fun x => f (p x)). -Proof. - intros n p f [pinv Hp]. - rewrite 2 basis_f_to_vec. - unfold perm_to_matrix, perm_mat, qubit_perm_to_nat_perm. - unfold basis_vector, Mmult, compose. - prep_matrix_equality. - destruct ((x =? funbool_to_nat n (fun x0 : nat => f (p x0))) && (y =? 0)) eqn:H. - apply andb_prop in H as [H1 H2]. - rewrite Nat.eqb_eq in H1. - rewrite Nat.eqb_eq in H2. - apply big_sum_unique. - exists (funbool_to_nat n f). - split. - apply funbool_to_nat_bound. - split. - erewrite funbool_to_nat_eq. - 2: { intros. rewrite funbool_to_nat_inverse. reflexivity. - destruct (Hp x0) as [? _]; auto. } - specialize (funbool_to_nat_bound n f) as ?. - specialize (funbool_to_nat_bound n (fun x0 : nat => f (p x0))) as ?. - bdestruct_all; lca. - intros z Hz H3. - bdestructΩ (z =? funbool_to_nat n f). - lca. - apply (@big_sum_0 C C_is_monoid). - intros z. - bdestruct_all; simpl; try lca. - rewrite andb_true_r in H. - apply Nat.eqb_neq in H. - subst z. - erewrite funbool_to_nat_eq in H2. - 2: { intros. rewrite funbool_to_nat_inverse. reflexivity. - destruct (Hp x0) as [? _]; auto. } - contradiction. -Qed. - -Lemma perm_to_matrix_unitary : forall n p, - permutation n p -> - WF_Unitary (perm_to_matrix n p). -Proof. - intros. - apply perm_mat_unitary. - apply qubit_perm_to_nat_perm_bij. - assumption. -Qed. - -Lemma qubit_perm_to_nat_perm_compose : forall n f g, - permutation n f -> - (qubit_perm_to_nat_perm n f ∘ qubit_perm_to_nat_perm n g = - qubit_perm_to_nat_perm n (g ∘ f))%prg. -Proof. - intros n f g [finv Hbij]. - unfold qubit_perm_to_nat_perm, compose. - apply functional_extensionality. - intro x. - apply funbool_to_nat_eq. - intros y Hy. - rewrite funbool_to_nat_inverse. - reflexivity. - destruct (Hbij y) as [? _]; auto. -Qed. - -Lemma perm_to_matrix_Mmult : forall n f g, - permutation n f -> - permutation n g -> - perm_to_matrix n f × perm_to_matrix n g = perm_to_matrix n (g ∘ f)%prg. -Proof. - intros. - unfold perm_to_matrix. - rewrite perm_mat_Mmult. - rewrite qubit_perm_to_nat_perm_compose by assumption. - reflexivity. - apply qubit_perm_to_nat_perm_bij. - assumption. -Qed. - -Lemma perm_to_matrix_I : forall n f, - permutation n f -> - (forall x, x < n -> f x = x) -> - perm_to_matrix n f = I (2 ^ n). -Proof. - intros n f g Hbij. - unfold perm_to_matrix. - apply perm_mat_I. - intros x Hx. - unfold qubit_perm_to_nat_perm, compose. - erewrite funbool_to_nat_eq. - 2: { intros y Hy. rewrite Hbij by assumption. reflexivity. } - apply nat_to_funbool_inverse. - assumption. -Qed. - -Lemma perm_to_matrix_WF : forall n p, WF_Matrix (perm_to_matrix n p). -Proof. intros. apply perm_mat_WF. Qed. -#[export] Hint Resolve perm_to_matrix_WF : wf_db. +Require Export PermutationAutomation. +Require Export PermutationInstances. +Require Export PermutationMatrices. \ No newline at end of file diff --git a/PermutationsBase.v b/PermutationsBase.v new file mode 100644 index 0000000..91fe40e --- /dev/null +++ b/PermutationsBase.v @@ -0,0 +1,1262 @@ +Require Import Bits. +Require Import Modulus. + +(** Facts about permutations *) +Declare Scope perm_scope. +Local Open Scope perm_scope. +Local Open Scope nat_scope. + +Create HintDb perm_db. +Create HintDb perm_bounded_db. +Create HintDb perm_inv_db. +Create HintDb WF_Perm_db. + +(** Permutations on (0, ..., n-1) *) +Definition permutation (n : nat) (f : nat -> nat) := + exists g, forall x, x < n -> (f x < n /\ g x < n /\ g (f x) = x /\ f (g x) = x). + +Lemma permutation_is_injective : forall n f, + permutation n f -> + forall x y, x < n -> y < n -> f x = f y -> x = y. +Proof. + intros n f [g Hbij] x y Hx Hy H. + destruct (Hbij x Hx) as [_ [_ [H0 _]]]. + destruct (Hbij y Hy) as [_ [_ [H1 _]]]. + rewrite <- H0. + rewrite <- H1. + rewrite H. + reflexivity. +Qed. + +Lemma permutation_is_surjective : forall n f, + permutation n f -> + forall k, k < n -> exists k', k' < n /\ f k' = k. +Proof. + intros n f Hf k Hk. + destruct Hf as [finv Hfinv]. + specialize (Hfinv k Hk). + exists (finv k). + intuition. +Qed. + +Lemma permutation_compose : forall n f g, + permutation n f -> + permutation n g -> + permutation n (f ∘ g)%prg. +Proof. + intros n f g [finv Hfbij] [ginv Hgbij]. + exists (ginv ∘ finv)%prg. + unfold compose. + intros x Hx. + destruct (Hgbij x) as [? [_ [? _]]]; auto. + destruct (Hfbij (g x)) as [? [_ [Hinv1 _]]]; auto. + destruct (Hfbij x) as [_ [? [_ ?]]]; auto. + destruct (Hgbij (finv x)) as [_ [? [_ Hinv2]]]; auto. + repeat split; auto. + rewrite Hinv1. + assumption. + rewrite Hinv2. + assumption. +Qed. + +(** The identity permutation *) +Notation idn := (fun (k : nat) => k). + +Lemma compose_idn_l : forall {T} (f : T -> nat), (idn ∘ f = f)%prg. +Proof. + intros. + unfold compose. + apply functional_extensionality; easy. +Qed. + +Lemma compose_idn_r : forall {T} (f : nat -> T), (f ∘ idn = f)%prg. +Proof. + intros. + unfold compose. + apply functional_extensionality; easy. +Qed. + +#[export] Hint Rewrite @compose_idn_r @compose_idn_l : perm_cleanup_db. + +Lemma idn_permutation : forall n, permutation n idn. +Proof. + intros. + exists idn. + easy. +Qed. + +Global Hint Resolve idn_permutation : perm_db. + +(** Notions of injectivity, boundedness, and surjectivity of f : nat -> nat + interpreted as a function from [n]_0 to [n]_0) and their equivalences *) +Notation perm_surj n f := (forall k, k < n -> exists k', k' < n /\ f k' = k). +Notation perm_bounded n f := (forall k, k < n -> f k < n). +Notation perm_inj n f := (forall k l, k < n -> l < n -> f k = f l -> k = l). + +Lemma fswap_injective_if_injective : forall {A} n (f:nat -> A) x y, + x < n -> y < n -> + perm_inj n f -> perm_inj n (fswap f x y). +Proof. + intros A n f x y Hx Hy Hinj k l Hk Hl. + unfold fswap. + bdestruct (k =? x); bdestruct (k =? y); + bdestruct (l =? x); bdestruct (l =? y); + subst; + intros Heq; + apply Hinj in Heq; lia. +Qed. + +Lemma fswap_injective_iff_injective : forall {A} n (f:nat -> A) x y, + x < n -> y < n -> + perm_inj n f <-> perm_inj n (fswap f x y). +Proof. + intros A n f x y Hx Hy. + split. + - apply fswap_injective_if_injective; easy. + - intros Hinj. + rewrite <- (fswap_involutive f x y). + apply fswap_injective_if_injective; easy. +Qed. + +Lemma fswap_surjective_if_surjective : forall n f x y, + x < n -> y < n -> + perm_surj n f -> perm_surj n (fswap f x y). +Proof. + intros n f x y Hx Hy Hsurj k Hk. + destruct (Hsurj k Hk) as [k' [Hk' Hfk']]. + bdestruct (k' =? x); [|bdestruct (k' =? y)]. + - exists y. + split; [assumption|]. + subst. + rewrite fswap_simpl2. + easy. + - exists x. + split; [assumption|]. + subst. + rewrite fswap_simpl1. + easy. + - exists k'. + split; [assumption|]. + rewrite fswap_neq; lia. +Qed. + +Lemma fswap_surjective_iff_surjective : forall n f x y, + x < n -> y < n -> + perm_surj n f <-> perm_surj n (fswap f x y). +Proof. + intros n f x y Hx Hy. + split. + - apply fswap_surjective_if_surjective; easy. + - intros Hsurj. + rewrite <- (fswap_involutive f x y). + apply fswap_surjective_if_surjective; easy. +Qed. + +Lemma fswap_bounded_if_bounded : forall n f x y, + x < n -> y < n -> + perm_bounded n f -> perm_bounded n (fswap f x y). +Proof. + intros n f x y Hx Hy Hbounded k Hk. + unfold fswap. + bdestruct_all; + apply Hbounded; + easy. +Qed. + +Lemma fswap_bounded_iff_bounded : forall n f x y, + x < n -> y < n -> + perm_bounded n f <-> perm_bounded n (fswap f x y). +Proof. + intros n f x y Hx Hy. + split. + - apply fswap_bounded_if_bounded; easy. + - intros Hbounded. + rewrite <- (fswap_involutive f x y). + apply fswap_bounded_if_bounded; easy. +Qed. + +Lemma surjective_of_eq_boundary_shrink : forall n f, + perm_surj (S n) f -> f n = n -> perm_surj n f. +Proof. + intros n f Hsurj Hfn k Hk. + assert (HkS : k < S n) by lia. + destruct (Hsurj k HkS) as [k' [Hk' Hfk']]. + bdestruct (k' =? n). + - exfalso; subst; lia. + - exists k'. + split; [lia | assumption]. +Qed. + +Lemma surjective_of_eq_boundary_grow : forall n f, + perm_surj n f -> f n = n -> perm_surj (S n) f. +Proof. + intros n f Hsurj Hfn k Hk. + bdestruct (k =? n). + - exists n; lia. + - assert (H'k : k < n) by lia. + destruct (Hsurj k H'k) as [k' [Hk' Hfk']]. + exists k'; lia. +Qed. + +Lemma fswap_at_boundary_surjective : forall n f n', + n' < S n -> perm_surj (S n) f -> f n' = n -> + perm_surj n (fswap f n' n). +Proof. + intros n f n' Hn' Hsurj Hfn' k Hk. + bdestruct (k =? f n). + - exists n'. + split. + + assert (Hneq: n' <> n); [|lia]. + intros Hfalse. + rewrite Hfalse in Hfn'. + rewrite Hfn' in H. + lia. + + rewrite fswap_simpl1; easy. + - assert (H'k : k < S n) by lia. + destruct (Hsurj k H'k) as [k' [Hk' Hfk']]. + assert (Hk'n: k' <> n) by (intros Hfalse; subst; lia). + assert (Hk'n': k' <> n') by (intros Hfalse; subst; lia). + exists k'. + split; [lia|]. + rewrite fswap_neq; lia. +Qed. + +Lemma injective_monotone : forall {A} n (f : nat -> A) m, + m < n -> perm_inj n f -> perm_inj m f. +Proof. + intros A n f m Hmn Hinj k l Hk Hl Hfkl. + apply Hinj; auto; lia. +Qed. + +Lemma injective_and_bounded_grow_of_boundary : forall n f, + perm_inj n f /\ perm_bounded n f -> f n = n -> + perm_inj (S n) f /\ perm_bounded (S n) f. +Proof. + intros n f [Hinj Hbounded] Hfn. + split. + - intros k l Hk Hl Hfkl. + bdestruct (k =? n). + + subst. + bdestruct (l =? n); [easy|]. + assert (H'l : l < n) by lia. + specialize (Hbounded _ H'l). + lia. + + assert (H'k : k < n) by lia. + bdestruct (l =? n). + * specialize (Hbounded _ H'k). + subst. lia. + * assert (H'l : l < n) by lia. + apply Hinj; easy. + - intros k Hk. + bdestruct (k perm_inj n f /\ perm_bounded n f. +Proof. + intros n. + induction n; [easy|]. + intros f Hsurj. + assert (HnS : n < S n) by lia. + destruct (Hsurj n HnS) as [n' [Hn' Hfn']]. + pose proof (fswap_at_boundary_surjective _ _ _ Hn' Hsurj Hfn') as Hswap_surj. + specialize (IHn (fswap f n' n) Hswap_surj). + rewrite (fswap_injective_iff_injective _ f n' n); [|easy|easy]. + rewrite (fswap_bounded_iff_bounded _ f n' n); [|easy|easy]. + apply injective_and_bounded_grow_of_boundary; + [| rewrite fswap_simpl2; easy]. + easy. +Qed. + +Lemma injective_and_bounded_shrink_of_boundary : forall n f, + perm_inj (S n) f /\ perm_bounded (S n) f -> f n = n -> + perm_inj n f /\ perm_bounded n f. +Proof. + intros n f [Hinj Hbounded] Hfn. + split. + - eapply injective_monotone, Hinj; lia. + - intros k Hk. + assert (H'k : k < S n) by lia. + specialize (Hbounded k H'k). + bdestruct (f k =? n). + + rewrite <- Hfn in H. + assert (HnS : n < S n) by lia. + specialize (Hinj _ _ H'k HnS H). + lia. + + lia. +Qed. + +(* Formalization of proof sketch of pigeonhole principle + from https://math.stackexchange.com/a/910790 *) +Lemma exists_bounded_decidable : forall n P, + (forall k, k < n -> {P k} + {~ P k}) -> + {exists j, j < n /\ P j} + {~ exists j, j < n /\ P j}. +Proof. + intros n P HPdec. + induction n. + - right; intros [x [Hlt0 _]]; inversion Hlt0. + - destruct (HPdec n) as [HPn | HnPn]; [lia| |]. + + left. exists n; split; [lia | assumption]. + + destruct IHn as [Hex | Hnex]. + * intros k Hk; apply HPdec; lia. + * left. + destruct Hex as [j [Hjn HPj]]. + exists j; split; [lia | assumption]. + * right. + intros [j [Hjn HPj]]. + apply Hnex. + bdestruct (j =? n). + -- exfalso; apply HnPn; subst; easy. + -- exists j; split; [lia | easy]. +Qed. + +Lemma has_preimage_decidable : forall n f, + forall k, k < n -> + {exists j, j < n /\ f j = k} + {~exists j, j < n /\ f j = k}. +Proof. + intros n f k Hk. + apply exists_bounded_decidable. + intros k' Hk'. + bdestruct (f k' =? k). + - left; easy. + - right; easy. +Qed. + +Lemma pigeonhole_S : forall n f, + (forall i, i < S n -> f i < n) -> + exists i j, i < S n /\ j < i /\ f i = f j. +Proof. + intros n. + destruct n; + [intros f Hbounded; specialize (Hbounded 0); lia|]. + induction n; intros f Hbounded. + 1: { + exists 1, 0. + pose (Hbounded 0). + pose (Hbounded 1). + lia. + } + destruct (has_preimage_decidable (S (S n)) f (f (S (S n)))) as [Hex | Hnex]. + - apply Hbounded; lia. + - destruct Hex as [j [Hj Hfj]]. + exists (S (S n)), j. + repeat split; lia. + - destruct (IHn (fun k => if f k f k >= f (S (S n)) -> f k > f (S (S n))). 1:{ + intros k Hk Hge. + bdestruct (f k =? f (S (S n))). + - exfalso; apply Hnex; exists k; split; lia. + - lia. + } + bdestruct (f i exists k, k < S n /\ f k = n. +Proof. + intros n f [Hinj Hbounded]. + destruct (has_preimage_decidable (S n) f n) as [Hex | Hnex]; + [lia | assumption |]. + (* Now, contradict injectivity using pigeonhole principle *) + exfalso. + assert (Hbounded': forall j, j < S n -> f j < n). 1:{ + intros j Hj. + specialize (Hbounded j Hj). + bdestruct (f j =? n). + - exfalso; apply Hnex; exists j; easy. + - lia. + } + destruct (pigeonhole_S n f Hbounded') as [i [j [Hi [Hj Heq]]]]. + absurd (i = j). + - lia. + - apply Hinj; lia. +Qed. + +Lemma surjective_of_injective_and_bounded : forall n f, + perm_inj n f /\ perm_bounded n f -> perm_surj n f. +Proof. + induction n; [easy|]. + intros f Hinj_bounded. + destruct (n_has_preimage_of_injective_and_bounded n f Hinj_bounded) as [n' [Hn' Hfn']]. + rewrite (fswap_injective_iff_injective _ _ n n') in Hinj_bounded; + [|lia|lia]. + rewrite (fswap_bounded_iff_bounded _ _ n n') in Hinj_bounded; + [|lia|lia]. + rewrite (fswap_surjective_iff_surjective _ _ n n'); + [|lia|easy]. + intros k Hk. + bdestruct (k =? n). + - exists n. + split; [lia|]. + rewrite fswap_simpl1; subst; easy. + - pose proof (injective_and_bounded_shrink_of_boundary n _ Hinj_bounded) as Hinj_bounded'. + rewrite fswap_simpl1 in Hinj_bounded'. + specialize (Hinj_bounded' Hfn'). + destruct (IHn (fswap f n n') Hinj_bounded' k) as [k' [Hk' Hfk']]; [lia|]. + exists k'. + split; [lia|assumption]. +Qed. + +(** Explicit inverse of a permutation *) +Fixpoint perm_inv n f k : nat := + match n with + | 0 => 0%nat + | S n' => if f n' =? k then n'%nat else perm_inv n' f k + end. + +Lemma perm_inv_bounded_S : forall n f k, + perm_inv (S n) f k < S n. +Proof. + intros n f k. + induction n; simpl. + - bdestructΩ (f 0 =? k). + - bdestruct (f (S n) =? k); [|transitivity (S n); [apply IHn|]]. + all: apply Nat.lt_succ_diag_r. +Qed. + +Lemma perm_inv_bounded : forall n f, + perm_bounded n (perm_inv n f). +Proof. + induction n. + - easy. + - intros. + apply perm_inv_bounded_S. +Qed. + +#[export] Hint Resolve perm_inv_bounded_S perm_inv_bounded : perm_bounded_db. + +Lemma perm_inv_is_linv_of_injective : forall n f, + perm_inj n f -> + forall k, k < n -> perm_inv n f (f k) = k. +Proof. + intros n f Hinj k Hk. + induction n. + - easy. + - simpl. + bdestruct (f n =? f k). + + apply Hinj; lia. + + assert (k <> n) by (intros Heq; subst; easy). + apply IHn; [auto|]. + assert (k <> n) by (intros Heq; subst; easy). + lia. +Qed. + +Lemma perm_inv_is_rinv_of_surjective' : forall n f k, + (exists l, l < n /\ f l = k) -> + f (perm_inv n f k) = k. +Proof. + intros n f k. + induction n. + - intros []; easy. + - intros [l [Hl Hfl]]. + simpl. + bdestruct (f n =? k); [easy|]. + apply IHn. + exists l. + split; [|easy]. + bdestruct (l =? n); [subst; easy|]. + lia. +Qed. + +Lemma perm_inv_is_rinv_of_surjective : forall n f, + perm_surj n f -> forall k, k < n -> + f (perm_inv n f k) = k. +Proof. + intros n f Hsurj k Hk. + apply perm_inv_is_rinv_of_surjective', Hsurj, Hk. +Qed. + +Lemma perm_inv_is_linv_of_permutation : forall n f, + permutation n f -> + forall k, k < n -> perm_inv n f (f k) = k. +Proof. + intros n f Hperm. + apply perm_inv_is_linv_of_injective, permutation_is_injective, Hperm. +Qed. + +Lemma perm_inv_is_rinv_of_permutation : forall n f, + permutation n f -> + forall k, k < n -> f (perm_inv n f k) = k. +Proof. + intros n f Hperm k Hk. + apply perm_inv_is_rinv_of_surjective', (permutation_is_surjective _ _ Hperm _ Hk). +Qed. + +Lemma perm_inv_is_inv_of_surjective_injective_bounded : forall n f, + perm_surj n f -> perm_inj n f -> perm_bounded n f -> + (forall k, k < n -> + f k < n /\ perm_inv n f k < n /\ perm_inv n f (f k) = k /\ f (perm_inv n f k) = k). +Proof. + intros n f Hsurj Hinj Hbounded. + intros k Hk; repeat split. + - apply Hbounded, Hk. + - apply perm_inv_bounded, Hk. + - rewrite perm_inv_is_linv_of_injective; easy. + - rewrite perm_inv_is_rinv_of_surjective'; [easy|]. + apply Hsurj; easy. +Qed. + +Lemma permutation_iff_surjective : forall n f, + permutation n f <-> perm_surj n f. +Proof. + split. + - apply permutation_is_surjective. + - intros Hsurj. + exists (perm_inv n f). + pose proof (injective_and_bounded_of_surjective n f Hsurj). + apply perm_inv_is_inv_of_surjective_injective_bounded; easy. +Qed. + +Lemma perm_inv_permutation n f : permutation n f -> + permutation n (perm_inv n f). +Proof. + intros Hperm. + exists f. + intros k Hk; repeat split. + - apply perm_inv_bounded, Hk. + - destruct Hperm as [? H]; apply H, Hk. + - rewrite perm_inv_is_rinv_of_permutation; easy. + - rewrite perm_inv_is_linv_of_permutation; easy. +Qed. + +#[export] Hint Resolve perm_inv_permutation : perm_db. + +Lemma permutation_is_bounded n f : permutation n f -> + perm_bounded n f. +Proof. + intros [finv Hfinv] k Hk. + destruct (Hfinv k Hk); easy. +Qed. + +Lemma id_permutation : forall n, + permutation n Datatypes.id. +Proof. + intros. + exists Datatypes.id. + intros. + unfold Datatypes.id. + easy. +Qed. + +Lemma fswap_permutation : forall n f x y, + permutation n f -> + (x < n)%nat -> + (y < n)%nat -> + permutation n (fswap f x y). +Proof. + intros. + replace (fswap f x y) with (f ∘ (fswap (fun i => i) x y))%prg. + apply permutation_compose; auto. + exists (fswap (fun i => i) x y). + intros. unfold fswap. + bdestruct_all; subst; auto. + apply functional_extensionality; intros. + unfold compose, fswap. + bdestruct_all; easy. +Qed. + +Lemma fswap_at_boundary_permutation : forall n f x, + permutation (S n) f -> + (x < S n)%nat -> f x = n -> + permutation n (fswap f x n). +Proof. + intros n f x. + rewrite 2!permutation_iff_surjective. + intros HsurjSn Hx Hfx. + apply fswap_at_boundary_surjective; easy. +Qed. + +(** Well-foundedness of permutations; f k = k for k not in [n]_0 *) +Definition WF_Perm (n : nat) (f : nat -> nat) := + forall k, n <= k -> f k = k. + +Lemma monotonic_WF_Perm n m f : WF_Perm n f -> n <= m -> + WF_Perm m f. +Proof. + intros HWF Hnm k Hk. + apply HWF; lia. +Qed. + +Lemma compose_WF_Perm n f g : WF_Perm n f -> WF_Perm n g -> + WF_Perm n (f ∘ g)%prg. +Proof. + unfold compose. + intros Hf Hg k Hk. + rewrite Hg, Hf; easy. +Qed. + +#[export] Hint Resolve compose_WF_Perm : WF_Perm_db. + +Lemma linv_WF_of_WF {n} {f finv} + (HfWF : WF_Perm n f) (Hinv : (finv ∘ f = idn)%prg) : + WF_Perm n finv. +Proof. + intros k Hk. + rewrite <- (HfWF k Hk). + unfold compose in Hinv. + apply (f_equal_inv k) in Hinv. + rewrite Hinv, (HfWF k Hk). + easy. +Qed. + +Lemma bounded_of_WF_linv {n} {f finv} + (HWF: WF_Perm n f) (Hinv : (finv ∘ f = idn)%prg) : + perm_bounded n f. +Proof. + intros k Hk. + pose proof (linv_WF_of_WF HWF Hinv) as HWFinv. + unfold compose in Hinv. + apply (f_equal_inv k) in Hinv. + bdestruct (f k nat) (n:nat) {finv finv'} + (Hf: permutation n f) (HfWF : WF_Perm n f) + (Hfinv : (finv ∘ f = idn)%prg) (Hfinv' : (finv' ∘ f = idn)%prg) : + finv = finv'. +Proof. + apply functional_extensionality; intros k. + pose proof (linv_WF_of_WF HfWF Hfinv) as HfinvWF. + pose proof (linv_WF_of_WF HfWF Hfinv') as Hfinv'WF. + bdestruct (n <=? k). + - rewrite HfinvWF, Hfinv'WF; easy. + - destruct Hf as [fi Hfi]. + specialize (Hfi k H). + unfold compose in Hfinv, Hfinv'. + apply (f_equal_inv (fi k)) in Hfinv, Hfinv'. + replace (f (fi k)) with k in * by easy. + rewrite Hfinv, Hfinv'. + easy. +Qed. + +Lemma permutation_monotonic_of_WF f m n : (m <= n)%nat -> + permutation m f -> WF_Perm m f -> + permutation n f. +Proof. + intros Hmn [finv_m Hfinv_m] HWF. + exists (fun k => if m <=? k then k else finv_m k). + intros k Hk. + bdestruct (m <=? k). + - rewrite HWF; bdestruct_all; auto. + - specialize (Hfinv_m _ H). + repeat split; bdestruct_all; try easy; lia. +Qed. + + +Definition perm_eq (n : nat) (f g : nat -> nat) := + forall k, k < n -> f k = g k. + +Lemma perm_eq_refl (n : nat) (f : nat -> nat) : + perm_eq n f f. +Proof. + easy. +Qed. + +Lemma perm_eq_sym {n} {f g : nat -> nat} : + perm_eq n f g -> perm_eq n g f. +Proof. + intros H k Hk; symmetry; auto. +Qed. + +Lemma perm_eq_trans {n} {f g h : nat -> nat} : + perm_eq n f g -> perm_eq n g h -> perm_eq n f h. +Proof. + intros Hfg Hgh k Hk; + rewrite Hfg; auto. +Qed. + +Lemma eq_of_WF_perm_eq n f g : WF_Perm n f -> WF_Perm n g -> + perm_eq n f g -> f = g. +Proof. + intros HfWF HgWF Heq. + apply functional_extensionality; intros k. + bdestruct (k perm_bounded n finv -> + perm_eq n (f ∘ finv)%prg idn <-> perm_eq n (finv ∘ f)%prg idn. +Proof. + intros Hperm Hbounded. + split; unfold compose. + - intros Hrinv. + intros k Hk. + apply (permutation_is_injective n f Hperm); try easy. + + apply Hbounded, permutation_is_bounded, Hk. + apply Hperm. + + rewrite Hrinv; [easy|]. + apply (permutation_is_bounded n f Hperm _ Hk). + - intros Hlinv k Hk. + destruct Hperm as [fi Hf]. + destruct (Hf k Hk) as [Hfk [Hfik [Hfifk Hffik]]]. + rewrite <- Hffik. + rewrite Hlinv; easy. +Qed. + +Notation is_perm_rinv n f finv := (perm_eq n (f ∘ finv)%prg idn) (only parsing). +Notation is_perm_linv n f finv := (perm_eq n (finv ∘ f)%prg idn) (only parsing). +Notation is_perm_inv n f finv := + (perm_eq n (f ∘ finv)%prg idn /\ perm_eq n (finv ∘ f)%prg idn) (only parsing). + +Lemma perm_linv_injective_of_surjective n f finv finv' : + perm_surj n f -> is_perm_linv n f finv -> is_perm_linv n f finv' -> + perm_eq n finv finv'. +Proof. + intros Hsurj Hfinv Hfinv' k Hk. + destruct (Hsurj k Hk) as [k' [Hk' Hfk']]. + rewrite <- Hfk'. + unfold compose in *. + rewrite Hfinv, Hfinv'; easy. +Qed. + +Lemma perm_bounded_rinv_injective_of_injective n f finv finv' : + perm_inj n f -> perm_bounded n finv -> perm_bounded n finv' -> + is_perm_rinv n f finv -> is_perm_rinv n f finv' -> + perm_eq n finv finv'. +Proof. + intros Hinj Hbounded Hbounded' Hfinv Hfinv' k Hk. + apply Hinj; auto. + unfold compose in *. + rewrite Hfinv, Hfinv'; easy. +Qed. + +Lemma permutation_inverse_injective n f finv finv' : permutation n f -> + is_perm_inv n f finv -> is_perm_inv n f finv' -> + perm_eq n finv finv'. +Proof. + intros Hperm Hfinv Hfinv'. + eapply perm_linv_injective_of_surjective. + + apply permutation_is_surjective, Hperm. + + destruct (Hfinv); auto. + + destruct (Hfinv'); auto. +Qed. + +Lemma perm_inv_perm_eq_injective (f : nat -> nat) (n : nat) + {finv finv' : nat -> nat} (Hf : permutation n f) : + perm_eq n (finv ∘ f)%prg idn -> + perm_eq n (finv' ∘ f)%prg idn -> + perm_eq n finv finv'. +Proof. + apply perm_linv_injective_of_surjective. + now apply permutation_is_surjective. +Qed. + +Fixpoint for_all_nat_lt (f : nat -> bool) (k : nat) := + match k with + | 0 => true + | S k' => f k' && for_all_nat_lt f k' + end. + +Lemma forall_nat_lt_S (P : forall k : nat, Prop) (n : nat) : + (forall k, k < S n -> P k) <-> P n /\ (forall k, k < n -> P k). +Proof. + split. + - intros Hall. + split; intros; apply Hall; lia. + - intros [Hn Hall]. + intros k Hk. + bdestruct (k=?n); [subst; easy | apply Hall; lia]. +Qed. + +Lemma for_all_nat_ltE {f : nat -> bool} {P : forall k : nat, Prop} + (ref : forall k, reflect (P k) (f k)) : + forall n, (forall k, k < n -> P k) <-> (for_all_nat_lt f n = true). +Proof. + induction n. + - easy. + - rewrite forall_nat_lt_S. + simpl. + rewrite andb_true_iff. + rewrite IHn. + apply and_iff_compat_r. + apply reflect_iff; easy. +Qed. + +Definition perm_inv_is_inv_pred (f : nat -> nat) (n : nat) : Prop := + forall k, k < n -> + f k < n /\ perm_inv n f k < n /\ + perm_inv n f (f k) = k /\ f (perm_inv n f k) = k. + +Definition is_permutation (f : nat -> nat) (n : nat) := + for_all_nat_lt + (fun k => + (f k nat) (n : nat) : + permutation n f <-> perm_inv_is_inv_pred f n. +Proof. + split. + - intros Hperm. + intros k Hk. + repeat split. + + destruct Hperm as [g Hg]; + apply (Hg k Hk). + + apply perm_inv_bounded; easy. + + apply perm_inv_is_linv_of_permutation; easy. + + apply perm_inv_is_rinv_of_permutation; easy. + - intros Hperminv. + exists (perm_inv n f); easy. +Qed. + +Lemma is_permutationE (f : nat -> nat) (n : nat) : + perm_inv_is_inv_pred f n <-> is_permutation f n = true. +Proof. + unfold perm_inv_is_inv_pred, is_permutation. + apply for_all_nat_ltE. + intros k. + apply iff_reflect. + rewrite 3!andb_true_iff. + rewrite 2!Nat.ltb_lt, 2!Nat.eqb_eq, 2!and_assoc. + easy. +Qed. + +Lemma permutation_iff_is_permutation (f : nat -> nat) (n : nat) : + permutation n f <-> is_permutation f n = true. +Proof. + rewrite permutation_iff_perm_inv_is_inv. + apply is_permutationE. +Qed. + +Lemma permutationP (f : nat -> nat) (n : nat) : + reflect (permutation n f) (is_permutation f n). +Proof. + apply iff_reflect, permutation_iff_is_permutation. +Qed. + +Definition permutation_dec (f : nat -> nat) (n : nat) : + {permutation n f} + {~ permutation n f} := + reflect_dec _ _ (permutationP f n). + + +Lemma big_sum_eq_up_to_fswap {G} `{Comm_Group G} + n (v : nat -> G) f x y (Hx : x < n) (Hy : y < n) : + big_sum (fun i => v (f i)) n = + big_sum (fun i => v (fswap f x y i)) n. +Proof. + bdestruct (x =? y); + [apply big_sum_eq_bounded; unfold fswap; intros; + bdestructΩ'|]. + bdestruct (x G) f (Hf : permutation n f) : + big_sum v n = big_sum (fun i => v (f i)) n. +Proof. + intros. + generalize dependent f. + induction n. + reflexivity. + intros f [g Hg]. + destruct (Hg n) as [_ [H1' [_ H2']]]; try lia. + symmetry. + rewrite (big_sum_eq_up_to_fswap _ v _ (g n) n) by auto. + repeat rewrite <- big_sum_extend_r. + rewrite fswap_simpl2. + rewrite H2'. + specialize (IHn (fswap f (g n) n)). + rewrite <- IHn; [easy|]. + apply fswap_at_boundary_permutation; auto. + exists g. auto. +Qed. + +(** vsum terms can be arbitrarily reordered *) +Lemma vsum_reorder : forall {d} n (v : nat -> Vector d) f, + permutation n f -> + big_sum v n = big_sum (fun i => v (f i)) n. +Proof. + intros d n v f Hf. + now apply big_sum_reorder. +Qed. + +(** Some special cases for @big_sum nat nat_is_monoid, to which the + above cannot apply because addition is commutative but does not + form a group. A class Comm_Monoid would generalize this. *) +Lemma Nsum_eq_up_to_fswap + n (v : nat -> nat) f x y (Hx : x < n) (Hy : y < n) : + big_sum (fun i => v (f i)) n = + big_sum (fun i => v (fswap f x y i)) n. +Proof. + bdestruct (x =? y); + [apply big_sum_eq_bounded; unfold fswap; intros; + bdestructΩ'|]. + bdestruct (x nat) f (Hf : permutation n f) : + big_sum v n = big_sum (v ∘ f)%prg n. +Proof. + intros. + generalize dependent f. + induction n. + reflexivity. + intros f Hf. + pose proof Hf as [g Hg]. + destruct (Hg n) as [_ [H1' [_ H2']]]; try lia. + symmetry. + rewrite (Nsum_eq_up_to_fswap _ _ _ (g n) n) by auto. + repeat rewrite <- big_sum_extend_r. + rewrite fswap_simpl2. + unfold compose. + + rewrite H2'. + specialize (IHn (fswap f (g n) n)). + rewrite IHn by + (apply fswap_at_boundary_permutation; auto). + simpl. + f_equal. + apply big_sum_eq_bounded. + intros k Hk. + unfold compose, fswap. + bdestructΩ'. +Qed. + +(** showing every permutation is a sequence of fswaps *) + +(* note the list acts on the left, for example, [s1,s2,...,sk] ⋅ f = s1 ⋅ ( ... ⋅ (sk ⋅ f)) *) +Fixpoint stack_fswaps (f : nat -> nat) (l : list (nat * nat)) := + match l with + | [] => f + | p :: ps => (fswap (Datatypes.id) (fst p) (snd p) ∘ (stack_fswaps f ps))%prg + end. + +Definition WF_fswap_stack n (l : list (nat * nat)) := + forall p, In p l -> (fst p < n /\ snd p < n). + +Lemma WF_fswap_stack_pop : forall n a l, + WF_fswap_stack n (a :: l) -> WF_fswap_stack n l. +Proof. intros. + unfold WF_fswap_stack in *. + intros. + apply H. + right; easy. +Qed. + +Lemma WF_fswap_stack_cons : forall n a l, + fst a < n -> snd a < n -> WF_fswap_stack n l -> WF_fswap_stack n (a :: l). +Proof. intros. + unfold WF_fswap_stack in *. + intros. + destruct H2; subst; auto. +Qed. + +Lemma WF_fswap_miss : forall n l i, + WF_fswap_stack n l -> + n <= i -> + (stack_fswaps Datatypes.id l) i = i. +Proof. induction l. + intros; simpl; easy. + intros; simpl. + unfold compose. + rewrite IHl; auto. + unfold fswap, Datatypes.id; simpl. + destruct (H a). + left; auto. + bdestruct_all; try lia. + apply WF_fswap_stack_pop in H; auto. +Qed. + +Lemma stack_fswaps_permutation : forall {n} (f : nat -> nat) (l : list (nat * nat)), + WF_fswap_stack n l -> + permutation n f -> + permutation n (stack_fswaps f l). +Proof. induction l. + - intros. easy. + - intros. + simpl. + apply permutation_compose. + apply fswap_permutation. + apply id_permutation. + 3 : apply IHl; auto. + 3 : apply WF_fswap_stack_pop in H; auto. + all : apply H; left; easy. +Qed. + +#[export] Hint Resolve stack_fswaps_permutation : perm_db. + +Lemma stack_fswaps_cons : forall (p : nat * nat) (l : list (nat * nat)), + ((stack_fswaps Datatypes.id [p]) ∘ (stack_fswaps Datatypes.id l))%prg = + stack_fswaps Datatypes.id (p :: l). +Proof. intros. + simpl. + rewrite compose_id_right. + easy. +Qed. + +Lemma fswap_comm {A} (f : nat -> A) a b : + fswap f a b = fswap f b a. +Proof. + apply functional_extensionality; intros k. + unfold fswap. + bdestruct_all; now subst. +Qed. + +(* +Theorem all_perms_are_fswap_stacks : forall {n} f, + permutation n f -> + exists l, WF_fswap_stack n l /\ + perm_eq n f (stack_fswaps Datatypes.id l) /\ length l = n. +Proof. +*) + +Definition ordered_real_function n (f : nat -> R) := + forall i j, i < n -> j < n -> i <= j -> (f j <= f i)%R. + +Lemma get_real_function_min : forall {n} (f : nat -> R), + exists n0, (n0 < (S n))%nat /\ (forall i, (i < (S n))%nat -> (f n0 <= f i)%R). +Proof. induction n. + - intros. + exists O; intros. + split; auto. + intros. + destruct i; try lia. + lra. + - intros. + destruct (IHn f) as [n0 [H H0] ]. + destruct (Rlt_le_dec (f n0) (f (S n))). + + exists n0; intros. + split; try lia. + intros. + bdestruct (i =? (S n))%nat; subst. + lra. + apply H0. + bdestruct (n0 R), + exists l, WF_fswap_stack n l /\ + ordered_real_function n (f ∘ (stack_fswaps Datatypes.id l))%prg. +Proof. intros. + generalize dependent f. + induction n. + - intros; exists []. + split; auto. + unfold WF_fswap_stack; intros. + destruct H. + simpl. + unfold ordered_real_function; intros; lia. + - intros. + destruct (@get_real_function_min n f) as [n0 [H H0]]. + destruct (IHn (f ∘ (stack_fswaps Datatypes.id [(n0, n)]))%prg) as [l [H1 H2]]. + exists ((n0, n) :: l). + split. + apply WF_fswap_stack_cons; simpl; auto. + unfold WF_fswap_stack in *; intros. + apply H1 in H3. + lia. + rewrite compose_assoc, stack_fswaps_cons in H2. + unfold ordered_real_function in *. + intros. + bdestruct (j =? n); subst. + simpl. + rewrite <- compose_assoc. + assert (H' : permutation (S n) + (fswap Datatypes.id n0 n ∘ stack_fswaps Datatypes.id l)%prg). + { apply permutation_compose. + apply fswap_permutation; auto. + apply id_permutation. + apply stack_fswaps_permutation. + unfold WF_fswap_stack in *; intros. + apply H1 in H6. + lia. + apply id_permutation. } + unfold compose in *. + destruct H' as [g H6]. + destruct (H6 i); auto. + rewrite (WF_fswap_miss n); auto. + replace (fswap Datatypes.id n0 n n) with n0. + apply H0; easy. + unfold fswap, Datatypes.id. + bdestruct_all; simpl; easy. + bdestruct (j R), + {n0 : nat | (n0 < (S n))%nat /\ (forall i, (i < (S n))%nat -> (f n0 <= f i)%R)}. +Proof. induction n. + - intros. + exists O; intros. + split; auto. + intros. + destruct i; try lia. + lra. + - intros. + destruct (IHn f) as [n0 [H H0] ]. + destruct (Rlt_le_dec (f n0) (f (S n))). + + exists n0; intros. + split; try lia. + intros. + bdestruct (i =? (S n))%nat; subst. + lra. + apply H0. + bdestruct (n0 R), + {l : list (nat * nat) | WF_fswap_stack n l /\ + ordered_real_function n (f ∘ (stack_fswaps Datatypes.id l))%prg}. +Proof. intros. + generalize dependent f. + induction n. + - intros; exists []. + split; auto. + unfold WF_fswap_stack; intros. + destruct H. + simpl. + unfold ordered_real_function; intros; lia. + - intros. + destruct (@get_real_function_min_constructive n f) as [n0 [H H0]]. + destruct (IHn (f ∘ (stack_fswaps Datatypes.id [(n0, n)]))%prg) as [l [H1 H2]]. + exists ((n0, n) :: l). + split. + apply WF_fswap_stack_cons; simpl; auto. + unfold WF_fswap_stack in *; intros. + apply H1 in H3. + lia. + rewrite compose_assoc, stack_fswaps_cons in H2. + unfold ordered_real_function in *. + intros. + bdestruct (j =? n); subst. + simpl. + rewrite <- compose_assoc. + assert (H' : permutation (S n) + (fswap Datatypes.id n0 n ∘ stack_fswaps Datatypes.id l)%prg). + { apply permutation_compose. + apply fswap_permutation; auto. + apply id_permutation. + apply stack_fswaps_permutation. + unfold WF_fswap_stack in *; intros. + apply H1 in H6. + lia. + apply id_permutation. } + unfold compose in *. + destruct H' as [g H6]. + destruct (H6 i); auto. + rewrite (WF_fswap_miss n); auto. + replace (fswap Datatypes.id n0 n n) with n0. + apply H0; easy. + unfold fswap, Datatypes.id. + bdestruct_all; simpl; easy. + bdestruct (j rect_to_polar (polar_to_rect p) = p. @@ -434,3 +496,30 @@ Qed. (****) (*****) (***) + + +Definition Clog (c : C) := + (ln (Cmod c), get_arg c). + +Lemma CexpC_Clog (c : C) (Hc : c <> 0) : + CexpC (Clog c) = c. +Proof. + unfold Clog, CexpC. + cbn. + rewrite exp_ln. + - exact (rect_to_polar_to_rect c Hc). + - apply Cmod_gt_0, Hc. +Qed. + +Lemma Cexp_get_arg_unit (z : C) : Cmod z = 1 -> + Cexp (get_arg z) = z. +Proof. + intros Hmod. + rewrite <- (CexpC_Clog z) at 2 by + (intros H; rewrite H, Cmod_0 in Hmod; lra). + rewrite Cexp_CexpC. + f_equal. + unfold Clog. + rewrite Hmod, ln_1. + reflexivity. +Qed. diff --git a/Prelim.v b/Prelim.v index 2f118ac..d885a93 100644 --- a/Prelim.v +++ b/Prelim.v @@ -100,15 +100,30 @@ Proof. intros. rewrite H. easy. Qed. Lemma f_equal_gen : forall {A B} (f g : A -> B) a b, f = g -> a = b -> f a = g b. Proof. intros. subst. reflexivity. Qed. -(** Currying *) +(** Lists *) -Definition curry {A B C : Type} (f : A * B -> C) : (A -> B -> C) := - fun x y => f (x,y). +Lemma map_nth_eq [A B] (dnew : A) (f : A -> B) (l : list A) (d : B) i : + f dnew = d -> + nth i (map f l) d = f (nth i l dnew). +Proof. + intros <-. + apply map_nth. +Qed. -Definition uncurry {A B C : Type} (f : A -> B -> C) : (A * B -> C) := - fun p => f (fst p) (snd p). +Lemma map_map_nth {A B} (f : A -> B) (l : list (list A)) i : + nth i (map (map f) l) [] = map f (nth i l []). +Proof. + now apply map_nth_eq. +Qed. -(** Lists *) +Lemma map_nth_small [A B] (dnew : A) (f : A -> B) (l : list A) (d : B) i : + i < length l -> + nth i (map f l) d = f (nth i l dnew). +Proof. + intros Hi. + rewrite (nth_indep _ d (f dnew)) by (now rewrite map_length). + apply map_nth. +Qed. Notation "l !! i" := (nth_error l i) (at level 20). @@ -139,11 +154,9 @@ Qed. Lemma repeat_combine : forall A n1 n2 (a : A), List.repeat a n1 ++ List.repeat a n2 = List.repeat a (n1 + n2). -Proof. - induction n1; trivial. - intros. simpl. - rewrite IHn1. - reflexivity. +Proof. + intros. + now rewrite repeat_app. Qed. Lemma rev_repeat : forall A (a : A) n, rev (repeat a n) = repeat a n. @@ -153,7 +166,7 @@ Proof. rewrite (repeat_combine A n 1). rewrite Nat.add_1_r. reflexivity. -Qed. +Qed. Lemma firstn_repeat_le : forall A (a : A) m n, (m <= n)%nat -> firstn m (repeat a n) = repeat a m. @@ -198,21 +211,6 @@ Proof. apply IHm. Qed. -Lemma skipn_length : forall {A} (l : list A) n, - length (skipn n l) = (length l - n)%nat. -Proof. - Transparent skipn. - intros A l. - induction l. - intros [|n]; easy. - intros [|n]. - easy. - simpl. - rewrite IHl. - easy. - Opaque skipn. -Qed. - Lemma nth_firstn : forall {A} i n (l : list A) d, (i < n)%nat -> nth i (firstn n l) d = nth i l d. Proof. @@ -289,6 +287,9 @@ Ltac apply_with_obligations H := [replace g with g'; [apply H|]|]|]|]|]|]|]; trivial end end. +Ltac gen a := + generalize dependent a. + (** From SF - up to five arguments *) Tactic Notation "gen" ident(X1) := generalize dependent X1. diff --git a/Quantum.v b/Quantum.v index b818491..2cf8906 100644 --- a/Quantum.v +++ b/Quantum.v @@ -5,6 +5,7 @@ Require Import Psatz. Require Import Reals. Require Export VecSet. Require Export CauchySchwarz. +Require Import Kronecker. (* Using our (complex, unbounded) matrices, their complex numbers *) @@ -96,12 +97,21 @@ Definition xbasis_plus : Vector 2 := / (√ 2) .* (∣0⟩ .+ ∣1⟩). Definition xbasis_minus : Vector 2 := / (√ 2) .* (∣0⟩ .+ ((-1) .* ∣1⟩)). Definition ybasis_plus : Vector 2 := / (√ 2) .* (∣0⟩ .+ Ci .* ∣1⟩). Definition ybasis_minus : Vector 2 := / (√ 2) .* (∣0⟩ .+ ((-Ci) .* ∣1⟩)). +Definition braminus := / √ 2 .* (⟨0∣ .+ (-1 .* ⟨1∣)). +Definition braplus := / √ 2 .* (⟨0∣ .+ ⟨1∣). Notation "∣+⟩" := xbasis_plus. Notation "∣-⟩" := xbasis_minus. +Notation "⟨+∣" := braplus. +Notation "⟨-∣" := braminus. Notation "∣R⟩" := ybasis_plus. Notation "∣L⟩" := ybasis_minus. +Lemma xbasis_plus_spec : ∣+⟩ = / √ 2 .* (∣0⟩ .+ ∣1⟩). +Proof. reflexivity. Qed. +Lemma xbasis_minus_spec : ∣-⟩ = / √ 2 .* (∣0⟩ .+ (- 1) .* (∣1⟩)). +Proof. reflexivity. Qed. + (* defining the EPR pair *) Definition EPRpair : Vector 4 := / (√ 2) .* (∣0,0⟩ .+ ∣1,1⟩). @@ -160,14 +170,6 @@ Definition sqrtx : Matrix 2 2 := | _, _ => C0 end. -Lemma sqrtx_sqrtx : sqrtx × sqrtx = σx. -Proof. - unfold sqrtx, σx, Mmult. - prep_matrix_equality. - destruct_m_eq; - autorewrite with trig_db C_db; try lca. -Qed. - Definition control {n : nat} (A : Matrix n n) : Matrix (2*n) (2*n) := fun x y => if (x C0 end. -Lemma cnot_eq : cnot = control σx. -Proof. - unfold cnot, control, σx. - solve_matrix. -Qed. - Definition notc : Matrix (2*2) (2*2) := fun x y => match x, y with | 1, 3 => 1%C @@ -285,40 +281,134 @@ Definition Sgate : Matrix 2 2 := phase_shift (PI / 2). Definition Tgate := phase_shift (PI / 4). +(** Well Formedness of Quantum States and Unitaries **) + +Lemma WF_bra0 : WF_Matrix ⟨0∣. Proof. show_wf. Qed. +Lemma WF_bra1 : WF_Matrix ⟨1∣. Proof. show_wf. Qed. +Lemma WF_qubit0 : WF_Matrix ∣0⟩. Proof. show_wf. Qed. +Lemma WF_qubit1 : WF_Matrix ∣1⟩. Proof. show_wf. Qed. +Lemma WF_braket0 : WF_Matrix ∣0⟩⟨0∣. Proof. show_wf. Qed. +Lemma WF_braket1 : WF_Matrix ∣1⟩⟨1∣. Proof. show_wf. Qed. + +#[deprecated(note="Use WF_braket0 instead")] +Notation WF_braqubit0 := WF_braket0 (only parsing). +#[deprecated(note="Use WF_braket1 instead")] +Notation WF_braqubit1 := WF_braket1 (only parsing). + +Lemma WF_bool_to_ket : forall b, WF_Matrix (bool_to_ket b). +Proof. destruct b; show_wf. Qed. +Lemma WF_bool_to_matrix : forall b, WF_Matrix (bool_to_matrix b). +Proof. destruct b; show_wf. Qed. +Lemma WF_bool_to_matrix' : forall b, WF_Matrix (bool_to_matrix' b). +Proof. destruct b; show_wf. Qed. + +Lemma WF_ket : forall n, WF_Matrix (ket n). +Proof. destruct n; simpl; show_wf. Qed. +Lemma WF_bra : forall n, WF_Matrix (bra n). +Proof. destruct n; simpl; show_wf. Qed. + +Lemma WF_bools_to_matrix : forall l, + @WF_Matrix (2^(length l)) (2^(length l)) (bools_to_matrix l). +Proof. + induction l; auto with wf_db. + unfold bools_to_matrix in *; simpl. + apply WF_kron; try rewrite map_length; try lia. + apply WF_bool_to_matrix. + apply IHl. +Qed. + + +Lemma WF_xbasis_plus : WF_Matrix ∣+⟩. Proof. show_wf. Qed. +Lemma WF_xbasis_minus : WF_Matrix ∣-⟩. Proof. show_wf. Qed. +Lemma WF_braplus : WF_Matrix (⟨+∣). Proof. show_wf. Qed. +Lemma WF_braminus : WF_Matrix (⟨-∣). Proof. show_wf. Qed. +Lemma WF_ybasis_plus : WF_Matrix ∣R⟩. Proof. show_wf. Qed. +Lemma WF_ybasis_minus : WF_Matrix ∣L⟩. Proof. show_wf. Qed. + + +#[export] Hint Resolve WF_bra0 WF_bra1 WF_qubit0 WF_qubit1 WF_braket0 WF_braket1 : wf_db. +#[export] Hint Resolve WF_bool_to_ket WF_bool_to_matrix WF_bool_to_matrix' : wf_db. +#[export] Hint Resolve WF_ket WF_bra WF_bools_to_matrix : wf_db. +#[export] Hint Resolve WF_xbasis_plus WF_xbasis_minus WF_braplus WF_braminus + WF_ybasis_plus WF_ybasis_minus : wf_db. + +Lemma WF_EPRpair : WF_Matrix ∣Φ+⟩. Proof. unfold EPRpair. auto with wf_db. Qed. + +#[export] Hint Resolve WF_EPRpair : wf_db. + + +Lemma WF_hadamard : WF_Matrix hadamard. Proof. show_wf. Qed. +Lemma WF_σx : WF_Matrix σx. Proof. show_wf. Qed. +Lemma WF_σy : WF_Matrix σy. Proof. show_wf. Qed. +Lemma WF_σz : WF_Matrix σz. Proof. show_wf. Qed. +Lemma WF_sqrtx : WF_Matrix sqrtx. Proof. show_wf. Qed. +Lemma WF_cnot : WF_Matrix cnot. Proof. show_wf. Qed. +Lemma WF_notc : WF_Matrix notc. Proof. show_wf. Qed. +Lemma WF_swap : WF_Matrix swap. Proof. show_wf. Qed. + +Lemma WF_rotation : forall θ ϕ λ, WF_Matrix (rotation θ ϕ λ). Proof. intros. show_wf. Qed. +Lemma WF_x_rotation : forall θ, WF_Matrix (x_rotation θ). Proof. intros. show_wf. Qed. +Lemma WF_y_rotation : forall θ, WF_Matrix (y_rotation θ). Proof. intros. show_wf. Qed. +Lemma WF_phase : forall ϕ, WF_Matrix (phase_shift ϕ). Proof. intros. show_wf. Qed. + +Lemma WF_Sgate : WF_Matrix Sgate. Proof. show_wf. Qed. +Lemma WF_Tgate: WF_Matrix Tgate. Proof. show_wf. Qed. + +Lemma WF_control : forall (n : nat) (U : Matrix n n), + WF_Matrix U -> WF_Matrix (control U). +Proof. + intros n U WFU. + unfold control, WF_Matrix in *. + intros x y [Hx | Hy]; + bdestruct (x apply WF_phase : wf_db. +#[export] Hint Extern 2 (WF_Matrix (control _)) => apply WF_control : wf_db. + +Lemma sqrtx_sqrtx : sqrtx × sqrtx = σx. +Proof. + prep_matrix_equivalence. + by_cell; lca. +Qed. + +Lemma cnot_eq : cnot = control σx. +Proof. + prep_matrix_equivalence. + now by_cell. +Qed. + Lemma x_rotation_pi : x_rotation PI = -Ci .* σx. Proof. - unfold σx, x_rotation, scale. - prep_matrix_equality. - destruct_m_eq; - autorewrite with trig_db C_db; - reflexivity. + prep_matrix_equivalence. + unfold x_rotation. + autorewrite with trig_db. + by_cell; lca. Qed. Lemma y_rotation_pi : y_rotation PI = -Ci .* σy. Proof. - unfold σy, y_rotation, scale. - prep_matrix_equality. - destruct_m_eq; - autorewrite with trig_db C_db; - try reflexivity. + prep_matrix_equivalence. + unfold y_rotation. + autorewrite with trig_db. + by_cell; lca. Qed. Lemma hadamard_rotation : rotation (PI/2) 0 PI = hadamard. Proof. - unfold hadamard, rotation. - prep_matrix_equality. - destruct_m_eq; try reflexivity; - unfold Cexp; apply injective_projections; simpl; - autorewrite with R_db; - autorewrite with trig_db; - autorewrite with R_db; - try reflexivity. - all: rewrite Rmult_assoc; - replace (/2 * /2)%R with (/4)%R by lra; - repeat rewrite <- Rdiv_unfold; - autorewrite with trig_db; - rewrite sqrt2_div2; - lra. + prep_matrix_equivalence. + unfold rotation, hadamard. + replace (PI / 2 / 2)%R with (PI / 4)%R by lra. + autorewrite with trig_db. + (* autorewrite with R_db. *) + rewrite Rplus_0_l, Cexp_PI, Cexp_0, Cmult_1_l. + rewrite <- RtoC_opp, <- !RtoC_mult, <- RtoC_div. + by_cell; lca. Qed. Lemma pauli_x_rotation : rotation PI 0 PI = σx. @@ -384,16 +474,11 @@ Qed. Lemma I_rotation : rotation 0 0 0 = I 2. Proof. - unfold I, rotation. - prep_matrix_equality. - destruct_m_eq; try reflexivity; - unfold Cexp; apply injective_projections; simpl; - autorewrite with R_db; - autorewrite with trig_db; - autorewrite with R_db; - try reflexivity. - bdestruct (x =? y); bdestruct (S (S x) - A × swap × swap = A. -Proof. - intros. - rewrite Mmult_assoc. - rewrite swap_swap. - Msimpl. - reflexivity. -Qed. - -#[global] Hint Rewrite swap_swap swap_swap_r using (auto 100 with wf_db): Q_db. (* TODO: move these swap lemmas to Permutation.v? *) @@ -566,167 +667,284 @@ Eval compute in ((swap_two 1 0 1) 0 0)%nat. Eval compute in (print_matrix (swap_two 1 0 2)). *) -(** Well Formedness of Quantum States and Unitaries **) +Lemma swap_eq_kron_comm : swap = kron_comm 2 2. +Proof. solve_matrix_fast_with idtac reflexivity. Qed. -Lemma WF_bra0 : WF_Matrix ⟨0∣. Proof. show_wf. Qed. -Lemma WF_bra1 : WF_Matrix ⟨1∣. Proof. show_wf. Qed. -Lemma WF_qubit0 : WF_Matrix ∣0⟩. Proof. show_wf. Qed. -Lemma WF_qubit1 : WF_Matrix ∣1⟩. Proof. show_wf. Qed. -Lemma WF_braket0 : WF_Matrix ∣0⟩⟨0∣. Proof. show_wf. Qed. -Lemma WF_braket1 : WF_Matrix ∣1⟩⟨1∣. Proof. show_wf. Qed. +Lemma swap_swap : swap × swap = I (2*2). +Proof. rewrite swap_eq_kron_comm. apply kron_comm_mul_inv. Qed. -#[deprecated(note="Use WF_braket0 instead")] -Notation WF_braqubit0 := WF_braket0 (only parsing). -#[deprecated(note="Use WF_braket1 instead")] -Notation WF_braqubit1 := WF_braket1 (only parsing). +Lemma swap_swap_r : forall (A : Matrix (2*2) (2*2)), + WF_Matrix A -> + A × swap × swap = A. +Proof. + intros. + rewrite Mmult_assoc. + rewrite swap_swap. + now apply Mmult_1_r. +Qed. -Lemma WF_bool_to_ket : forall b, WF_Matrix (bool_to_ket b). -Proof. destruct b; show_wf. Qed. -Lemma WF_bool_to_matrix : forall b, WF_Matrix (bool_to_matrix b). -Proof. destruct b; show_wf. Qed. -Lemma WF_bool_to_matrix' : forall b, WF_Matrix (bool_to_matrix' b). -Proof. destruct b; show_wf. Qed. +#[global] Hint Rewrite swap_swap swap_swap_r using (auto 100 with wf_db): Q_db. -Lemma WF_ket : forall n, WF_Matrix (ket n). -Proof. destruct n; simpl; show_wf. Qed. -Lemma WF_bra : forall n, WF_Matrix (bra n). -Proof. destruct n; simpl; show_wf. Qed. +Lemma braplus_transpose_ketplus : +⟨+∣⊤ = ∣+⟩. +Proof. solve_matrix_fast. Qed. -Lemma WF_bools_to_matrix : forall l, - @WF_Matrix (2^(length l)) (2^(length l)) (bools_to_matrix l). -Proof. - induction l; auto with wf_db. - unfold bools_to_matrix in *; simpl. - apply WF_kron; try rewrite map_length; try lia. - apply WF_bool_to_matrix. - apply IHl. +Lemma braminus_transpose_ketminus : +⟨-∣⊤ = ∣-⟩. +Proof. solve_matrix_fast. Qed. + +Lemma Mmultplus0 : + ⟨+∣ × ∣0⟩ = / (√2)%R .* I 1. +Proof. + unfold braplus. + rewrite Mscale_mult_dist_l. + rewrite Mmult_plus_distr_r. + rewrite Mmult00. + rewrite Mmult10. + lma. Qed. +Lemma Mmult0plus : + ⟨0∣ × ∣+⟩ = / (√2)%R .* I 1. +Proof. + unfold xbasis_plus. + rewrite Mscale_mult_dist_r. + rewrite Mmult_plus_distr_l. + rewrite Mmult00. + rewrite Mmult01. + lma. +Qed. -Lemma WF_xbasis_plus : WF_Matrix ∣+⟩. Proof. show_wf. Qed. -Lemma WF_xbasis_minus : WF_Matrix ∣-⟩. Proof. show_wf. Qed. -Lemma WF_ybasis_plus : WF_Matrix ∣R⟩. Proof. show_wf. Qed. -Lemma WF_ybasis_minus : WF_Matrix ∣L⟩. Proof. show_wf. Qed. +Lemma Mmultplus1 : + ⟨+∣ × ∣1⟩ = / (√2)%R .* I 1. +Proof. + unfold braplus. + rewrite Mscale_mult_dist_l. + rewrite Mmult_plus_distr_r. + rewrite Mmult01. + rewrite Mmult11. + lma. +Qed. +Lemma Mmult1plus : + ⟨1∣ × ∣+⟩ = / (√2)%R .* I 1. +Proof. + unfold xbasis_plus. + rewrite Mscale_mult_dist_r. + rewrite Mmult_plus_distr_l. + rewrite Mmult10. + rewrite Mmult11. + lma. +Qed. -#[export] Hint Resolve WF_bra0 WF_bra1 WF_qubit0 WF_qubit1 WF_braket0 WF_braket1 WF_braqubit0 WF_braqubit1 : wf_db. -#[export] Hint Resolve WF_bool_to_ket WF_bool_to_matrix WF_bool_to_matrix' : wf_db. -#[export] Hint Resolve WF_ket WF_bra WF_bools_to_matrix : wf_db. -#[export] Hint Resolve WF_xbasis_plus WF_xbasis_minus WF_ybasis_plus WF_ybasis_minus : wf_db. +Lemma Mmultminus0 : + ⟨-∣ × ∣0⟩ = / (√2)%R .* I 1. +Proof. + unfold braminus. + rewrite Mscale_mult_dist_l. + rewrite Mmult_plus_distr_r. + rewrite Mmult00. + rewrite Mscale_mult_dist_l. + rewrite Mmult10. + lma. +Qed. -Lemma WF_EPRpair : WF_Matrix ∣Φ+⟩. Proof. unfold EPRpair. auto with wf_db. Qed. +Lemma Mmult0minus : + ⟨0∣ × ∣-⟩ = / (√2)%R .* I 1. +Proof. + unfold xbasis_minus. + rewrite Mscale_mult_dist_r. + rewrite Mmult_plus_distr_l. + rewrite Mmult00. + rewrite Mscale_mult_dist_r. + rewrite Mmult01. + lma. +Qed. -#[export] Hint Resolve WF_EPRpair : wf_db. +Lemma Mmultminus1 : + ⟨-∣ × ∣1⟩ = - / (√2)%R .* I 1. +Proof. + unfold braminus. + rewrite Mscale_mult_dist_l. + rewrite Mmult_plus_distr_r. + rewrite Mmult01. + rewrite Mscale_mult_dist_l. + rewrite Mmult11. + lma. +Qed. +Lemma Mmult1minus : + ⟨1∣ × ∣-⟩ = - / (√2)%R .* I 1. +Proof. + unfold xbasis_minus. + rewrite Mscale_mult_dist_r. + rewrite Mmult_plus_distr_l. + rewrite Mmult10. + rewrite Mscale_mult_dist_r. + rewrite Mmult11. + lma. +Qed. -Lemma WF_hadamard : WF_Matrix hadamard. Proof. show_wf. Qed. -Lemma WF_σx : WF_Matrix σx. Proof. show_wf. Qed. -Lemma WF_σy : WF_Matrix σy. Proof. show_wf. Qed. -Lemma WF_σz : WF_Matrix σz. Proof. show_wf. Qed. -Lemma WF_cnot : WF_Matrix cnot. Proof. show_wf. Qed. -Lemma WF_notc : WF_Matrix notc. Proof. show_wf. Qed. -Lemma WF_swap : WF_Matrix swap. Proof. show_wf. Qed. +Lemma Mmultminusminus : + ⟨-∣ × ∣-⟩ = I 1. +Proof. + prep_matrix_equivalence. + unfold braminus. + unfold xbasis_minus. + distribute_scale. + group_radicals. + by_cell; lca. +Qed. -Lemma WF_rotation : forall θ ϕ λ, WF_Matrix (rotation θ ϕ λ). Proof. intros. show_wf. Qed. -Lemma WF_phase : forall ϕ, WF_Matrix (phase_shift ϕ). Proof. intros. show_wf. Qed. +Lemma Mmultplusminus : + ⟨+∣ × ∣-⟩ = Zero. +Proof. + prep_matrix_equivalence. + unfold braplus, xbasis_minus. + distribute_scale. + group_radicals. + by_cell; lca. +Qed. -Lemma WF_Sgate : WF_Matrix Sgate. Proof. show_wf. Qed. -Lemma WF_Tgate: WF_Matrix Tgate. Proof. show_wf. Qed. +Lemma Mmultminusplus : + ⟨-∣ × ∣+⟩ = Zero. +Proof. + prep_matrix_equivalence. + unfold xbasis_plus, braminus. + distribute_scale. + group_radicals. + by_cell; lca. +Qed. -Lemma WF_control : forall (n : nat) (U : Matrix n n), - WF_Matrix U -> WF_Matrix (control U). +Lemma Mmultplusplus : + ⟨+∣ × ∣+⟩ = I 1. Proof. - intros n U WFU. - unfold control, WF_Matrix in *. - intros x y [Hx | Hy]; - bdestruct (x apply WF_phase : wf_db. -#[export] Hint Extern 2 (WF_Matrix (control _)) => apply WF_control : wf_db. +Lemma bra0transpose : + ⟨0∣⊤ = ∣0⟩. +Proof. solve_matrix_fast. Qed. + +Lemma bra1transpose : + ⟨1∣⊤ = ∣1⟩. +Proof. solve_matrix_fast. Qed. +Lemma ket0transpose : + ∣0⟩⊤ = ⟨0∣. +Proof. solve_matrix_fast. Qed. +Lemma ket1transpose : + ∣1⟩⊤ = ⟨1∣. +Proof. solve_matrix_fast. Qed. -(* how to make this proof shorter? *) +Lemma Mplus_plus_minus : ∣+⟩ .+ ∣-⟩ = (√2)%R .* ∣0⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_plus_minus_opp : ∣+⟩ .+ -1 .* ∣-⟩ = (√2)%R .* ∣1⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_minus_plus : ∣-⟩ .+ ∣+⟩ = (√2)%R .* ∣0⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_minus_opp_plus : -1 .* ∣-⟩ .+ ∣+⟩ = (√2)%R .* ∣1⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_0_1 : ∣0⟩ .+ ∣1⟩ = (√2)%R .* ∣+⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_0_1_opp : ∣0⟩ .+ -1 .* ∣1⟩ = (√2)%R .* ∣-⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_1_0 : ∣1⟩ .+ ∣0⟩ = (√2)%R .* ∣+⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma Mplus_1_opp_0 : -1 .* ∣1⟩ .+ ∣0⟩ = (√2)%R .* ∣-⟩. +Proof. + solve_matrix_fast_with + (unfold xbasis_plus, xbasis_minus; autounfold with U_db) + (cbn; try lca; Csimpl; C_field; lca). +Qed. + +Lemma σz_decomposition : σz = ∣0⟩⟨0∣ .+ -1 .* ∣1⟩⟨1∣. +Proof. solve_matrix_fast. Qed. + +(* It may be possible to remove the WF_Matrix B hypothesis *) Lemma direct_sum_decomp : forall (m n o p : nat) (A B : Matrix m n), - WF_Matrix A -> WF_Matrix B -> + WF_Matrix A -> WF_Matrix B -> A .⊕ B = ∣0⟩⟨0∣ ⊗ A .+ ∣1⟩⟨1∣ ⊗ B. -Proof. - intros. +Proof. + intros m n o p A B HA HB. + prep_matrix_equivalence. unfold direct_sum, kron, Mplus. - prep_matrix_equality. - bdestruct_all; try lia; simpl. - - repeat (rewrite Nat.div_small, Nat.mod_small; try easy); lca. - - rewrite H; auto. - destruct n. rewrite H, H0; try lca; try (right; lia). - rewrite (Nat.div_small x m), (Nat.mod_small x m); try easy. - replace (y / S n)%nat with (1 + (y - S n)/S n)%nat. - unfold Mmult, adjoint; simpl. - destruct (fst (Nat.divmod (y - S n) n 0 n)); try lca. - rewrite <- Nat.div_add_l; auto. - replace ((1 * S n + (y - S n)))%nat with y by lia; easy. - - rewrite H; auto. - destruct m. rewrite H, H0; try lca; try (left; lia). - rewrite (Nat.div_small y n), (Nat.mod_small y n); try easy. - replace (x / S m)%nat with (1 + (x - S m)/S m)%nat. - unfold Mmult, adjoint; simpl. - destruct (fst (Nat.divmod (x - S m) m 0 m)); try lca. - rewrite <- Nat.div_add_l; auto. - replace ((1 * S m + (x - S m)))%nat with x by lia; easy. - - destruct n; destruct m. - try (rewrite H, H0, H0; try lca); - try (left; lia); try (right; lia). - try (rewrite H, H0, H0; try lca); - try (left; lia); try (right; lia). - try (rewrite H, H0, H0; try lca); - try (left; lia); try (right; lia). - bdestruct (x - S m WF_Unitary (control A). Proof. intros n A H. destruct H as [WF U]. split; auto with wf_db. - unfold control, adjoint, Mmult, I. - prep_matrix_equality. - simpl. - bdestructΩ (x =? y). - - subst; simpl. - rewrite big_sum_sum. - bdestructΩ (y A x (y - n)%nat ^* * A x (y - n)%nat)). - ++ unfold control, adjoint, Mmult, I in U. - rewrite Nat.add_0_r. - eapply (equal_f) in U. - eapply (equal_f) in U. - rewrite U. - rewrite Nat.eqb_refl. simpl. - bdestructΩ (y - n A z (x - n)%nat ^* * A z (y - n)%nat)). - ++ unfold control, adjoint, Mmult, I in U. - rewrite Nat.add_0_r. - eapply (equal_f) in U. - eapply (equal_f) in U. - rewrite U. - bdestructΩ (x - n =? y - n). - simpl. - easy. - ++ apply functional_extensionality. intros z. - bdestructΩ (n + z WF_Unitary (A⊤). Proof. - intros. - simpl. - split. - + destruct H; auto with wf_db. - + destruct H. - replace ((A⊤)†) with ((A†)⊤). - rewrite <- Mmult_transpose. - rewrite Minv_flip; auto with wf_db. - prep_matrix_equality. - unfold transpose, I. - bdestruct_all; easy. - prep_matrix_equality. - unfold transpose, adjoint. - easy. + intros n A [HA U]. + split; [auto_wf|]. + change ((A⊤)†) with ((A†)⊤). + rewrite <- Mmult_transpose. + rewrite Minv_flip; auto with wf_db. + apply id_transpose_eq. Qed. Lemma adjoint_unitary : forall n (A : Matrix n n), WF_Unitary A -> WF_Unitary (A†). Proof. - intros. - simpl. - split. - + destruct H; auto with wf_db. - + unfold WF_Unitary in *. - rewrite adjoint_involutive. - destruct H as [H H0]. - apply Minv_left in H0 as [_ S]; - auto with wf_db. + intros n A [HA U]. + split; [auto_wf|]. + rewrite adjoint_involutive. + rewrite Minv_flip; auto with wf_db. Qed. Lemma cnot_unitary : WF_Unitary cnot. Proof. - split. - apply WF_cnot. - unfold Mmult, I. - prep_matrix_equality. - do 4 (try destruct x; try destruct y; try lca). - replace ((S (S (S (S x))) WF_Unitary a) -> WF_Unitary (⨂ ls). -Proof. intros. induction ls as [| h]. - - simpl. apply id_unitary. - - simpl. - apply kron_unitary. - apply (H h). - left. easy. - apply IHls. - intros. - apply H. right. easy. +Proof. + intros. induction ls as [| h]. + - apply id_unitary. + - simpl. + apply kron_unitary. + + apply (H h). + left. easy. + + apply IHls. + intros. + apply H. right. easy. Qed. (* alternate version for more general length application *) Lemma big_kron_unitary' : forall (n m : nat) (ls : list (Square n)), - length ls = m -> (forall a, In a ls -> WF_Unitary a) -> @WF_Unitary (n^m) (⨂ ls). -Proof. intros; subst. induction ls as [| h]. - - simpl. apply id_unitary. - - simpl. - apply kron_unitary. - apply (H0 h). - left. easy. - apply IHls. - intros. - apply H0. right. easy. + length ls = m -> (forall a, In a ls -> WF_Unitary a) -> + @WF_Unitary (n^m) (⨂ ls). +Proof. + intros; subst. + now apply big_kron_unitary. Qed. Lemma Mmult_unitary : forall (n : nat) (A : Square n) (B : Square n), @@ -1165,8 +1182,7 @@ Lemma Mmult_unitary : forall (n : nat) (A : Square n) (B : Square n), WF_Unitary (A × B). Proof. intros n A B [WFA UA] [WFB UB]. - split. - auto with wf_db. + split; [auto_wf|]. Msimpl. rewrite Mmult_assoc. rewrite <- (Mmult_assoc A†). @@ -1185,31 +1201,49 @@ Proof. auto with wf_db. distribute_adjoint. distribute_scale. - rewrite UA, Cmult_comm, H. - lma. + rewrite UA, Cmult_comm, H. + apply Mscale_1_l. Qed. Lemma pad1_unitary : forall (n : nat) (c : C) (A : Square n), WF_Unitary A -> (c * c ^*)%C = C1 -> WF_Unitary (pad1 A c). -Proof. intros. - split. - destruct H; auto with wf_db. - rewrite pad1_adjoint, <- pad1_mult. - destruct H. - rewrite H1, Cmult_comm, H0, pad1_I. - easy. +Proof. + intros n c A [WFA U] Hc. + split; [auto_wf|]. + rewrite pad1_adjoint, <- pad1_mult. + rewrite U, Cmult_comm, Hc, pad1_I. + easy. Qed. - #[export] Hint Resolve transpose_unitary adjoint_unitary cnot_unitary notc_unitary id_unitary : unit_db. #[export] Hint Resolve swap_unitary zero_not_unitary kron_unitary big_kron_unitary big_kron_unitary' Mmult_unitary scale_unitary pad1_unitary : unit_db. +Lemma unitary_conj_trans_real {n} {A : Matrix n n} (HA : WF_Unitary A) i j : + snd ((A †%M × A) i j) = 0. +Proof. + destruct HA as [_ Heq]. + apply (f_equal_inv i) in Heq. + apply (f_equal_inv j) in Heq. + rewrite Heq. + unfold I. + Modulus.bdestructΩ'. +Qed. + +Lemma unitary_conj_trans_nonneg {n} {A : Matrix n n} + (HA : WF_Unitary A) i j : + 0 <= fst ((A †%M × A) i j). +Proof. + rewrite (proj2 HA). + unfold I; Modulus.bdestructΩ'; simpl; lra. +Qed. + + Lemma hadamard_st : hadamard ⊤ = hadamard. -Proof. solve_matrix. Qed. +Proof. solve_matrix_fast. Qed. Lemma adjoint_transpose_comm : forall m n (A : Matrix m n), A † ⊤ = A ⊤ †. @@ -1227,45 +1261,28 @@ Definition hermitian {n} (A : Square n) := Lemma I_hermitian : forall {n}, hermitian (I n). -Proof. intros. - apply id_adjoint_eq. +Proof. + intros. + apply id_adjoint_eq. Qed. Lemma hadamard_hermitian : hermitian hadamard. -Proof. - prep_matrix_equality. - repeat (try destruct x; try destruct y; try lca; trivial). -Qed. +Proof. solve_matrix_fast. Qed. Lemma σx_hermitian : hermitian σx. -Proof. - prep_matrix_equality. - repeat (try destruct x; try destruct y; try lca; trivial). -Qed. +Proof. solve_matrix_fast. Qed. Lemma σy_hermitian : hermitian σy. -Proof. - prep_matrix_equality. - repeat (try destruct x; try destruct y; try lca; trivial). -Qed. +Proof. solve_matrix_fast. Qed. Lemma σz_hermitian : hermitian σz. -Proof. - prep_matrix_equality. - repeat (try destruct x; try destruct y; try lca; trivial). -Qed. +Proof. solve_matrix_fast. Qed. Lemma cnot_hermitian : hermitian cnot. -Proof. - prep_matrix_equality. - repeat (try destruct x; try destruct y; try lca; trivial). -Qed. +Proof. solve_matrix_fast. Qed. Lemma swap_hermitian : hermitian swap. -Proof. - prep_matrix_equality. - repeat (try destruct x; try destruct y; try lca; trivial). -Qed. +Proof. solve_matrix_fast. Qed. @@ -1274,83 +1291,77 @@ Qed. Lemma plus_hermitian : forall {n} (A B : Square n), hermitian A -> hermitian B -> hermitian (A .+ B). -Proof. intros n A B H H0. - unfold hermitian. - distribute_adjoint. - rewrite H, H0. - easy. +Proof. + intros n A B HA HB. + unfold hermitian. + distribute_adjoint. + rewrite HA, HB. + easy. Qed. Lemma adjoint_hermitian : forall {n} (A : Square n), hermitian A -> hermitian A†. -Proof. intros. - unfold hermitian. - do 2 rewrite H. - easy. +Proof. + intros n A HA. + unfold hermitian. + rewrite 2!HA. + easy. Qed. Lemma unit_conj_hermitian : forall {n} (A U : Square n), hermitian A -> WF_Unitary U -> hermitian (U × A × U†). -Proof. intros. - destruct H0; auto with wf_db. - unfold hermitian. - rewrite 2 Mmult_adjoint, adjoint_involutive, Mmult_assoc, H. - easy. +Proof. + intros n A U HA [HUWF HU]. + unfold hermitian in *. + rewrite 2 Mmult_adjoint, adjoint_involutive, Mmult_assoc, HA. + easy. Qed. Lemma AAadjoint_hermitian : forall {m n} (A : Matrix m n), hermitian (A × A†). -Proof. intros. - unfold hermitian. - rewrite Mmult_adjoint, adjoint_involutive. - easy. +Proof. + intros. + unfold hermitian. + rewrite Mmult_adjoint, adjoint_involutive. + easy. Qed. Lemma AadjointA_hermitian : forall {m n} (A : Matrix m n), hermitian (A† × A). -Proof. intros. - unfold hermitian. - rewrite Mmult_adjoint, adjoint_involutive. - easy. +Proof. + intros. + unfold hermitian. + rewrite Mmult_adjoint, adjoint_involutive. + easy. Qed. Lemma control_adjoint : forall n (U : Square n), (control U)† = control (U†). Proof. intros n U. - unfold control, adjoint. - prep_matrix_equality. - rewrite Nat.eqb_sym. - bdestruct (y =? x). - - subst. - bdestruct (x hermitian (control A). Proof. - intros n A H. + intros n A HA. unfold hermitian in *. rewrite control_adjoint. - rewrite H. + rewrite HA. easy. -Qed. +Qed. Lemma phase_adjoint : forall ϕ, (phase_shift ϕ)† = phase_shift (-ϕ). Proof. intros ϕ. + prep_matrix_equivalence. unfold phase_shift, adjoint. - prep_matrix_equality. - destruct_m_eq; try lca. - unfold Cexp, Cconj. - rewrite cos_neg, sin_neg. - easy. + rewrite <- Cexp_conj_neg. + by_cell; lca. Qed. @@ -1359,17 +1370,17 @@ Qed. Lemma rotation_adjoint : forall θ ϕ λ, (rotation θ ϕ λ)† = rotation (-θ) (-λ) (-ϕ). Proof. intros. + prep_matrix_equivalence. unfold rotation, adjoint. - prep_matrix_equality. - destruct_m_eq; try lca; - unfold Cexp, Cconj; - apply injective_projections; simpl; - try rewrite <- Ropp_plus_distr; - autorewrite with R_db; - autorewrite with trig_db; - try rewrite (Rplus_comm λ ϕ); - autorewrite with R_db; - reflexivity. + rewrite Rdiv_opp_l, cos_neg, sin_neg. + by_cell; rewrite ?Cconj_mult_distr; Csimpl. + - reflexivity. + - rewrite Cexp_conj_neg. + lca. + - rewrite Cconj_opp, Cexp_conj_neg. + lca. + - rewrite Cexp_conj_neg. + now rewrite Ropp_plus_distr, Rplus_comm. Qed. Lemma braket0_hermitian : hermitian ∣0⟩⟨0∣. Proof. lma. Qed. @@ -1512,22 +1523,24 @@ Lemma braqubit1_sa : ∣1⟩⟨1∣† = ∣1⟩⟨1∣. Proof. lma. Qed. Lemma phase_0 : phase_shift 0 = I 2. Proof. - unfold phase_shift, I. + prep_matrix_equivalence. + unfold phase_shift, I. rewrite Cexp_0. - solve_matrix. + by_cell; reflexivity. Qed. Lemma phase_2pi : phase_shift (2 * PI) = I 2. + prep_matrix_equivalence. unfold phase_shift, I. rewrite Cexp_2PI. - solve_matrix. + by_cell; reflexivity. Qed. Lemma phase_pi : phase_shift PI = σz. Proof. unfold phase_shift, σz. rewrite Cexp_PI. - replace (RtoC (-1)) with (Copp (RtoC 1)) by lca. + rewrite <- RtoC_opp. reflexivity. Qed. @@ -1542,7 +1555,8 @@ Qed. Lemma phase_mul : forall θ θ', phase_shift θ × phase_shift θ' = phase_shift (θ + θ'). Proof. - intros. solve_matrix. rewrite Cexp_add. reflexivity. + intros. + solve_matrix_fast_with idtac (cbn; rewrite 1?Cexp_add; lca). Qed. Lemma phase_pow : forall θ n, n ⨉ (phase_shift θ) = phase_shift (INR n * θ). @@ -1609,63 +1623,68 @@ Definition positive_semidefinite {n} (A : Square n) : Prop := Lemma positive_semidefinite_AAadjoint : forall {m n} (A : Matrix m n), positive_semidefinite (A × A†). -Proof. intros. - unfold positive_semidefinite. - intros. - replace (((z) † × (A × (A) †) × z) 0%nat 0%nat) with (⟨ A† × z, A† × z ⟩). - apply Rle_ge; apply inner_product_ge_0. - unfold inner_product. - distribute_adjoint. - rewrite adjoint_involutive, 3 Mmult_assoc. - easy. +Proof. + intros. + unfold positive_semidefinite. + intros. + replace (((z) † × (A × (A) †) × z) 0%nat 0%nat) with (⟨ A† × z, A† × z ⟩). + - apply Rle_ge; apply inner_product_ge_0. + - unfold inner_product. + distribute_adjoint. + rewrite adjoint_involutive, 3 Mmult_assoc. + easy. Qed. Lemma positive_semidefinite_AadjointA : forall {m n} (A : Matrix m n), positive_semidefinite (A† × A). -Proof. intros. - assert (H' := (positive_semidefinite_AAadjoint A†)). - rewrite adjoint_involutive in H'. - easy. +Proof. + intros. + assert (H' := (positive_semidefinite_AAadjoint A†)). + rewrite adjoint_involutive in H'. + easy. Qed. Lemma positive_semidefinite_unitary_conj : forall {n} (A U : Square n), WF_Unitary U -> positive_semidefinite A -> positive_semidefinite (U† × A × U). -Proof. intros. - unfold positive_semidefinite in *. - intros. - replace ((z) † × ((U) † × A × U) × z) with (((z) † × (U†)) × A × (U × z)). - rewrite <- Mmult_adjoint. - apply H0. - destruct H; auto with wf_db. - repeat rewrite Mmult_assoc; easy. +Proof. + intros n A U [HUWF HU] HA. + unfold positive_semidefinite in *. + intros z Hz. + replace ((z) † × ((U) † × A × U) × z) with (((z) † × (U†)) × A × (U × z)) + by (now rewrite !Mmult_assoc). + rewrite <- Mmult_adjoint. + apply HA. + auto_wf. Qed. Lemma positive_semidefinite_unitary_conj_conv : forall {n} (A U : Square n), WF_Unitary U -> positive_semidefinite (U† × A × U) -> positive_semidefinite A. -Proof. intros. - unfold positive_semidefinite in *. - intros. - replace ((z) † × A × z) with (((U† × z)† × (U† × A × U) × (U† × z))). - apply H0. - destruct H; auto with wf_db. - distribute_adjoint. - rewrite adjoint_involutive. - destruct H. - apply Minv_flip in H2; auto with wf_db. - rewrite 3 Mmult_assoc, <- (Mmult_assoc _ _ z), H2, Mmult_1_l; auto. - rewrite <- 2 (Mmult_assoc U), H2, <- 2 Mmult_assoc, Mmult_1_r; auto with wf_db. -Qed. - -Lemma pure_psd : forall (n : nat) (ϕ : Vector n), (WF_Matrix ϕ) -> positive_semidefinite (ϕ × ϕ†). +Proof. + intros n A U [HUWF HU] HA. + unfold positive_semidefinite in *. + intros z Hz. + replace ((z) † × A × z) with (((U† × z)† × (U† × A × U) × (U† × z))). + - apply HA; auto_wf. + - distribute_adjoint. + rewrite adjoint_involutive. + apply Minv_flip in HU; [|auto_wf..]. + rewrite 3 Mmult_assoc, <- (Mmult_assoc _ _ z), HU, Mmult_1_l by auto. + rewrite <- 2 (Mmult_assoc U), HU, <- 2 Mmult_assoc, Mmult_1_r + by auto with wf_db. + reflexivity. +Qed. + +Lemma pure_psd : forall (n : nat) (ϕ : Vector n), (WF_Matrix ϕ) -> + positive_semidefinite (ϕ × ϕ†). Proof. intros n ϕ WFϕ z WFZ. - repeat rewrite Mmult_assoc. + rewrite !Mmult_assoc. remember (ϕ† × z) as ψ. - repeat rewrite <- Mmult_assoc. + rewrite <- !Mmult_assoc. rewrite <- (adjoint_involutive _ _ ϕ). rewrite <- Mmult_adjoint. rewrite <- Heqψ. @@ -1674,26 +1693,26 @@ Proof. rewrite Rplus_0_l. unfold Rminus. rewrite Ropp_involutive. - replace (fst (z 1%nat 0%nat) * fst (z 1%nat 0%nat))%R with ((fst (z 1%nat 0%nat))²) by easy. - replace (snd (z 1%nat 0%nat) * snd (z 1%nat 0%nat))%R with ((snd (z 1%nat 0%nat))²) by easy. + rewrite <- 2!Rsqr_def. apply Rle_ge. apply Rplus_le_le_0_compat; apply Rle_0_sqr. Qed. Lemma braket0_psd : positive_semidefinite ∣0⟩⟨0∣. -Proof. apply pure_psd. auto with wf_db. Qed. +Proof. apply pure_psd. auto_wf. Qed. Lemma braket1_psd : positive_semidefinite ∣1⟩⟨1∣. -Proof. apply pure_psd. auto with wf_db. Qed. +Proof. apply pure_psd. auto_wf. Qed. Lemma H0_psd : positive_semidefinite (hadamard × ∣0⟩⟨0∣ × hadamard). Proof. - repeat rewrite Mmult_assoc. - rewrite <- hadamard_hermitian at 2. + rewrite !Mmult_assoc. + rewrite <- hadamard_hermitian_rw. rewrite <- Mmult_adjoint. - repeat rewrite <- Mmult_assoc. + rewrite <- !Mmult_assoc. + rewrite hadamard_hermitian_rw. apply pure_psd. - auto with wf_db. + auto_wf. Qed. @@ -1713,33 +1732,39 @@ Definition Pure_State {n} (ρ : Density n) : Prop := Inductive Mixed_State {n} : Matrix n n -> Prop := | Pure_S : forall ρ, Pure_State ρ -> Mixed_State ρ -| Mix_S : forall (p : R) ρ1 ρ2, 0 < p < 1 -> Mixed_State ρ1 -> Mixed_State ρ2 -> - Mixed_State (p .* ρ1 .+ (1-p)%R .* ρ2). +| Mix_S : forall (p : R) ρ1 ρ2, 0 < p < 1 -> + Mixed_State ρ1 -> Mixed_State ρ2 -> + Mixed_State (p .* ρ1 .+ (1-p)%R .* ρ2). Lemma WF_Pure : forall {n} (ρ : Density n), Pure_State ρ -> WF_Matrix ρ. -Proof. intros. destruct H as [φ [[WFφ IP1] Eρ]]. rewrite Eρ. auto with wf_db. Qed. +Proof. intros. destruct H as [φ [[WFφ IP1] Eρ]]. rewrite Eρ. auto_wf. Qed. #[export] Hint Resolve WF_Pure : wf_db. Lemma WF_Mixed : forall {n} (ρ : Density n), Mixed_State ρ -> WF_Matrix ρ. -Proof. induction 1; auto with wf_db. Qed. +Proof. intros n p H. induction H; auto_wf. Qed. #[export] Hint Resolve WF_Mixed : wf_db. Lemma pure0 : Pure_State ∣0⟩⟨0∣. -Proof. exists ∣0⟩. intuition. split. auto with wf_db. solve_matrix. Qed. +Proof. exists ∣0⟩. split; [|easy]. split; [auto_wf|]. apply Mmult00. Qed. Lemma pure1 : Pure_State ∣1⟩⟨1∣. -Proof. exists ∣1⟩. intuition. split. auto with wf_db. solve_matrix. Qed. +Proof. exists ∣1⟩. split; [|easy]. split; [auto_wf|]. apply Mmult11. Qed. -Lemma pure_id1 : Pure_State (I 1). -Proof. exists (I 1). split. split. auto with wf_db. solve_matrix. solve_matrix. Qed. +Lemma pure_id1 : Pure_State (I 1). +Proof. + exists (I 1). + split; [|now Msimpl_light]. + split; [auto_wf|]. + now Msimpl_light. +Qed. -Lemma pure_dim1 : forall (ρ : Square 1), Pure_State ρ -> ρ = I 1. +Lemma pure_dim1 : forall (ρ : Square 1), Pure_State ρ -> ρ = I 1. Proof. - intros. + intros p H. assert (H' := H). apply WF_Pure in H'. destruct H as [φ [[WFφ IP1] Eρ]]. - apply Minv_flip in IP1; auto with wf_db. + apply Minv_flip in IP1; [|auto_wf..]. rewrite Eρ; easy. Qed. @@ -1748,7 +1773,7 @@ Lemma pure_state_unitary_pres : forall {n} (ϕ : Vector n) (U : Square n), Proof. unfold Pure_State_Vector. intros n ϕ U [H H0] [H1 H2]. - split; auto with wf_db. + split; [auto_wf|]. distribute_adjoint. rewrite Mmult_assoc, <- (Mmult_assoc _ U), H2, Mmult_1_l; auto. Qed. @@ -1760,17 +1785,17 @@ Proof. intros n m ϕ ψ [WFu Pu] [WFv Pv]. split. - apply WF_kron; auto. - - Msimpl. rewrite Pu, Pv. Msimpl. easy. + - Msimpl. rewrite Pu, Pv. apply kron_1_r. Qed. - + Lemma pure_state_kron : forall m n (ρ : Square m) (φ : Square n), Pure_State ρ -> Pure_State φ -> Pure_State (ρ ⊗ φ). Proof. intros m n ρ φ [u [? Eρ]] [v [? Eφ]]. exists (u ⊗ v). split. - - apply pure_state_vector_kron; auto. - - Msimpl. subst. easy. + - apply pure_state_vector_kron; assumption. + - Msimpl. now subst. Qed. Lemma mixed_state_kron : forall m n (ρ : Square m) (φ : Square n), @@ -1788,24 +1813,21 @@ Proof. apply Mix_S; easy. Qed. -Lemma pure_state_trace_1 : forall {n} (ρ : Density n), Pure_State ρ -> trace ρ = 1. +Lemma pure_state_trace_1 : forall {n} (ρ : Density n), Pure_State ρ -> + trace ρ = 1. Proof. intros n ρ [u [[WFu Uu] E]]. subst. - clear -Uu. + clear WFu. unfold trace. unfold Mmult, adjoint in *. simpl in *. - match goal with - [H : ?f = ?g |- _] => assert (f O O = g O O) by (rewrite <- H; easy) - end. - unfold I in H; simpl in H. - rewrite <- H. - apply big_sum_eq. - apply functional_extensionality. - intros x. - rewrite Cplus_0_l, Cmult_comm. - easy. + apply (f_equal (fun A => A O O)) in Uu. + unfold I in Uu; simpl in Uu. + rewrite <- Uu. + apply big_sum_eq_bounded. + intros k Hk. + lca. Qed. Lemma mixed_state_trace_1 : forall {n} (ρ : Density n), Mixed_State ρ -> trace ρ = 1. @@ -1822,63 +1844,54 @@ Qed. (* The following two lemmas say that for any mixed states, the elements along the diagonal are real numbers in the [0,1] interval. *) -Lemma mixed_state_diag_in01 : forall {n} (ρ : Density n) i , Mixed_State ρ -> - 0 <= fst (ρ i i) <= 1. +Lemma mixed_state_diag_in01 : forall {n} (ρ : Density n) i, Mixed_State ρ -> + 0 <= fst (ρ i i) <= 1. Proof. intros. - induction H. - - destruct H as [φ [[WFφ IP1] Eρ]]. - destruct (lt_dec i n). - 2: rewrite Eρ; unfold Mmult, adjoint; simpl; rewrite WFφ; simpl; [lra|lia]. - rewrite Eρ. - unfold Mmult, adjoint in *. - simpl in *. - rewrite Rplus_0_l. - match goal with - [H : ?f = ?g |- _] => assert (f O O = g O O) by (rewrite <- H; easy) - end. - unfold I in H. simpl in H. clear IP1. - match goal with - [ H : ?x = ?y |- _] => assert (H': fst x = fst y) by (rewrite H; easy); clear H - end. - simpl in H'. - rewrite <- H'. - split. - + unfold Rminus. rewrite <- Ropp_mult_distr_r. rewrite Ropp_involutive. - rewrite <- Rplus_0_r at 1. - apply Rplus_le_compat; apply Rle_0_sqr. - + match goal with - [ |- ?x <= fst (big_sum ?f ?m)] => specialize (big_sum_member_le f n) as res - end. - simpl in *. - unfold Rminus in *. - rewrite <- Ropp_mult_distr_r. - rewrite Ropp_mult_distr_l. - apply res with (x := i); trivial. - intros x. - unfold Rminus. rewrite <- Ropp_mult_distr_l. rewrite Ropp_involutive. - rewrite <- Rplus_0_r at 1. - apply Rplus_le_compat; apply Rle_0_sqr. + induction H as [u Hu | p]. + - bdestruct (i u j j)); easy. - simpl. - repeat rewrite Rmult_0_l. - repeat rewrite Rminus_0_r. + rewrite !Rmult_0_l, !Rminus_0_r. split. - assert (0 <= p * fst (ρ1 i i)). - apply Rmult_le_pos; lra. - assert (0 <= (1 - p) * fst (ρ2 i i)). - apply Rmult_le_pos; lra. - lra. - assert (p * fst (ρ1 i i) <= p)%R. - rewrite <- Rmult_1_r. - apply Rmult_le_compat_l; lra. - assert ((1 - p) * fst (ρ2 i i) <= (1-p))%R. - rewrite <- Rmult_1_r. - apply Rmult_le_compat_l; lra. - lra. -Qed. - -Lemma mixed_state_diag_real : forall {n} (ρ : Density n) i , Mixed_State ρ -> - snd (ρ i i) = 0. + + assert (0 <= p * fst (ρ1 i i)) + by (apply Rmult_le_pos; lra). + assert (0 <= (1 - p) * fst (ρ2 i i)) + by (apply Rmult_le_pos; lra). + lra. + + assert (p * fst (ρ1 i i) <= p)%R + by (rewrite <- Rmult_1_r; + apply Rmult_le_compat_l; lra). + assert ((1 - p) * fst (ρ2 i i) <= (1-p))%R + by (rewrite <- Rmult_1_r; + apply Rmult_le_compat_l; lra). + lra. +Qed. + +Lemma mixed_state_diag_real : forall {n} (ρ : Density n) i , + Mixed_State ρ -> + snd (ρ i i) = 0. Proof. intros. induction H. @@ -1893,7 +1906,7 @@ Proof. lra. Qed. -Lemma mixed_dim1 : forall (ρ : Square 1), Mixed_State ρ -> ρ = I 1. +Lemma mixed_dim1 : forall (ρ : Square 1), Mixed_State ρ -> ρ = I 1. Proof. intros. induction H. @@ -1915,9 +1928,8 @@ Definition WF_Superoperator {m n} (f : Superoperator m n) := Definition super {m n} (M : Matrix m n) : Superoperator n m := fun ρ => M × ρ × M†. -Lemma super_I : forall n ρ, - WF_Matrix ρ -> - super (I n) ρ = ρ. +Lemma super_I : forall n ρ, WF_Matrix ρ -> + super (I n) ρ = ρ. Proof. intros. unfold super. @@ -1939,14 +1951,16 @@ Lemma super_outer_product : forall m (φ : Matrix m 1) (U : Matrix m m), Proof. intros. unfold super, outer_product. autorewrite with M_db Q_db. - repeat rewrite Mmult_assoc. reflexivity. + rewrite !Mmult_assoc. + reflexivity. Qed. -Definition compose_super {m n p} (g : Superoperator n p) (f : Superoperator m n) - : Superoperator m p := fun ρ => g (f ρ). +Definition compose_super {m n p} + (g : Superoperator n p) (f : Superoperator m n) : Superoperator m p := + fun ρ => g (f ρ). -Lemma WF_compose_super : forall m n p (g : Superoperator n p) (f : Superoperator m n) - (ρ : Square m), +Lemma WF_compose_super : forall m n p + (g : Superoperator n p) (f : Superoperator m n) (ρ : Square m), WF_Matrix ρ -> (forall A, WF_Matrix A -> WF_Matrix (f A)) -> (forall A, WF_Matrix A -> WF_Matrix (g A)) -> @@ -1960,10 +1974,10 @@ Qed. Lemma compose_super_correct : forall {m n p} - (g : Superoperator n p) (f : Superoperator m n), - WF_Superoperator g -> - WF_Superoperator f -> - WF_Superoperator (compose_super g f). + (g : Superoperator n p) (f : Superoperator m n), + WF_Superoperator g -> + WF_Superoperator f -> + WF_Superoperator (compose_super g f). Proof. intros m n p g f pf_g pf_f. unfold WF_Superoperator. @@ -1976,12 +1990,13 @@ Definition sum_super {m n} (f g : Superoperator m n) : Superoperator m n := fun ρ => (1/2)%R .* f ρ .+ (1 - 1/2)%R .* g ρ. Lemma sum_super_correct : forall m n (f g : Superoperator m n), - WF_Superoperator f -> WF_Superoperator g -> WF_Superoperator (sum_super f g). + WF_Superoperator f -> WF_Superoperator g -> + WF_Superoperator (sum_super f g). Proof. intros m n f g wf_f wf_g ρ pf_ρ. unfold sum_super. - set (wf_f' := wf_f _ pf_ρ). - set (wf_g' := wf_g _ pf_ρ). + pose proof (wf_f _ pf_ρ). + pose proof (wf_g _ pf_ρ). apply (Mix_S (1/2) (f ρ) (g ρ)); auto. lra. Qed. @@ -2050,7 +2065,7 @@ Proof. unfold compose_super, super. apply functional_extensionality. intros ρ. rewrite Mmult_adjoint. - repeat rewrite Mmult_assoc. + rewrite !Mmult_assoc. reflexivity. Qed. @@ -2065,31 +2080,28 @@ Ltac Qsimpl := try restore_dims; autorewrite with M_db_light M_db Q_db. (* Tests and Lemmas about swap matrices *) (****************************************) -Lemma swap_spec : forall (q q' : Vector 2), WF_Matrix q -> - WF_Matrix q' -> - swap × (q ⊗ q') = q' ⊗ q. +Lemma swap_spec : forall (q q' : Vector 2), + WF_Matrix q -> WF_Matrix q' -> + swap × (q ⊗ q') = q' ⊗ q. Proof. intros q q' WF WF'. - solve_matrix. - - destruct y. lca. - rewrite WF by lia. - rewrite (WF' O (S y)) by lia. - lca. - - destruct y. lca. - rewrite WF by lia. - rewrite (WF' O (S y)) by lia. - lca. - - destruct y. lca. - rewrite WF by lia. - rewrite (WF' 1%nat (S y)) by lia. - lca. - - destruct y. lca. - rewrite WF by lia. - rewrite (WF' 1%nat (S y)) by lia. - lca. + rewrite swap_eq_kron_comm. + rewrite kron_comm_commutes_l by easy. + rewrite kron_comm_1_l. + apply Mmult_1_r; auto_wf. Qed. -#[global] Hint Rewrite swap_spec using (auto 100 with wf_db) : Q_db. +#[global] Hint Rewrite swap_spec using auto_wf : Q_db. + +Lemma swap_transpose : swap ⊤%M = swap. +Proof. now rewrite swap_eq_kron_comm, kron_comm_transpose. Qed. + +Lemma swap_spec' : + swap = ((ket 0 × bra 0) ⊗ (ket 0 × bra 0) .+ (ket 0 × bra 1) ⊗ (ket 1 × bra 0) + .+ (ket 1 × bra 0) ⊗ (ket 0 × bra 1) .+ (ket 1 × bra 1) ⊗ (ket 1 × bra 1)). +Proof. + solve_matrix_fast_with idtac (cbv; lca). +Qed. Example swap_to_0_test_24 : forall (q0 q1 q2 q3 : Vector 2), WF_Matrix q0 -> WF_Matrix q1 -> WF_Matrix q2 -> WF_Matrix q3 -> @@ -2098,16 +2110,20 @@ Proof. intros q0 q1 q2 q3 WF0 WF1 WF2 WF3. unfold swap_to_0, swap_to_0_aux. simpl. - rewrite Mmult_assoc. - repeat rewrite Mmult_assoc. - rewrite (kron_assoc q0 q1) by auto with wf_db. Qsimpl. - replace 4%nat with (2*2)%nat by reflexivity. - repeat rewrite kron_assoc by auto with wf_db. + rewrite !Mmult_assoc. + rewrite (kron_assoc q0 q1) by auto with wf_db. + restore_dims. + rewrite 2!kron_mixed_product, swap_spec, 2!Mmult_1_l, + <- kron_assoc by auto_wf. + restore_dims. + rewrite (kron_assoc (_ ⊗ _)) by auto_wf. + rewrite kron_mixed_product, Mmult_1_l, swap_spec by auto_wf. restore_dims. - rewrite <- (kron_assoc q0 q2) by auto with wf_db. Qsimpl. - rewrite (kron_assoc q2) by auto with wf_db. Qsimpl. - rewrite <- kron_assoc by auto with wf_db. Qsimpl. - repeat rewrite <- kron_assoc by auto with wf_db. + rewrite <- kron_assoc, (kron_assoc q2) by auto_wf. + rewrite 2!kron_mixed_product. + rewrite 2!Mmult_1_l, swap_spec by auto_wf. + restore_dims. + rewrite <- !kron_assoc by auto_wf. reflexivity. Qed. @@ -2126,8 +2142,7 @@ Lemma swap_0_2 : swap_two 3 0 2 = (I 2 ⊗ swap) × (swap ⊗ I 2) × (I 2 ⊗ s Proof. unfold swap_two. simpl. - Qsimpl. - reflexivity. + now rewrite kron_1_r. Qed. (* @@ -2157,12 +2172,12 @@ Proof. rewrite (kron_assoc q0 q1) by auto with wf_db. simpl. restore_dims. - replace 4%nat with (2*2)%nat by reflexivity. - Qsimpl. - rewrite <- kron_assoc by auto with wf_db. + rewrite 2!kron_mixed_product, 2!Mmult_1_l, swap_spec, <- kron_assoc by easy. + restore_dims. + rewrite kron_assoc by auto_wf. + rewrite kron_mixed_product, Mmult_1_l, swap_spec by auto_wf. restore_dims. - repeat rewrite (kron_assoc _ q1) by auto with wf_db. - Qsimpl. + rewrite <- kron_assoc by auto_wf. reflexivity. Qed. diff --git a/RealAux.v b/RealAux.v index 817687e..c699476 100644 --- a/RealAux.v +++ b/RealAux.v @@ -18,6 +18,7 @@ Lemma Rlt_minus_l : forall a b c,(a - c < b <-> a < b + c). Proof. intros. lra. Lemma Rle_minus_r : forall a b c,(a <= b - c <-> a + c <= b). Proof. intros. lra. Qed. Lemma Rminus_le_0 : forall a b, a <= b <-> 0 <= b - a. Proof. intros. lra. Qed. Lemma Rminus_lt_0 : forall a b, a < b <-> 0 < b - a. Proof. intros. lra. Qed. +Lemma Ropp_lt_0 : forall x : R, x < 0 -> 0 < -x. Proof. intros. lra. Qed. (* Automation *) @@ -43,6 +44,21 @@ Proof. intros. unfold Rdiv. rewrite Rinv_mult; trivial. lra. Qed. Lemma Rdiv_cancel : forall r r1 r2 : R, r1 = r2 -> r / r1 = r / r2. Proof. intros. rewrite H. reflexivity. Qed. +(* FIXME: TODO: Remove; included in later versions of stdlib *) +Lemma Rdiv_0_r : forall r, r / 0 = 0. +Proof. intros. rewrite Rdiv_unfold, Rinv_0, Rmult_0_r. reflexivity. Qed. + +Lemma Rdiv_0_l : forall r, 0 / r = 0. +Proof. intros. rewrite Rdiv_unfold, Rmult_0_l. reflexivity. Qed. + +Lemma Rdiv_opp_l : forall r1 r2, - r1 / r2 = - (r1 / r2). +Proof. intros. lra. Qed. + +Lemma Rsqr_def : forall r, r² = r * r. +Proof. intros r. easy. Qed. + +(* END FIXME *) + Lemma Rsum_nonzero : forall r1 r2 : R, r1 <> 0 \/ r2 <> 0 -> r1 * r1 + r2 * r2 <> 0. Proof. intros. @@ -55,6 +71,54 @@ Proof. - specialize (pow_nonzero r2 2 H). intros NZ. lra. Qed. +Lemma Rmult_le_impl_le_disj_nonneg (x y z w : R) + (Hz : 0 <= z) (Hw : 0 <= w) : + x * y <= z * w -> x <= z \/ y <= w. +Proof. + destruct (Rle_or_lt x z), (Rle_or_lt y w); + [left + right; easy..|]. + assert (Hx : 0 < x) by lra. + assert (Hy : 0 < y) by lra. + intros Hfasle. + destruct Hz, Hw; enough (z * w < x * y) by lra; + [|subst; rewrite ?Rmult_0_l, ?Rmult_0_r; + apply Rmult_lt_0_compat; easy..]. + pose proof (Rmult_lt_compat_r w z x) as Ht1. + pose proof (Rmult_lt_compat_l x w y) as Ht2. + lra. +Qed. + +Lemma Rle_pow_le_nonneg (x y : R) (Hx : 0 <= x) (Hy : 0 <= y) n : + x ^ (S n) <= y ^ (S n) -> x <= y. +Proof. + induction n; [rewrite !pow_1; easy|]. + change (?z ^ S ?m) with (z * z ^ m). + intros H. + apply Rmult_le_impl_le_disj_nonneg in H; + [|easy|apply pow_le; easy]. + destruct H; auto. +Qed. + +Lemma Rabs_eq_0_iff a : Rabs a = 0 <-> a = 0. +Proof. + split; [|intros ->; apply Rabs_R0]. + unfold Rabs. + destruct (Rcase_abs a); lra. +Qed. + +Lemma Rplus_ge_0_of_ge_Rabs a b : Rabs a <= b -> 0 <= a + b. +Proof. + unfold Rabs. + destruct (Rcase_abs a); lra. +Qed. + +Lemma Rpow_0_l n : O <> n -> (0 ^ n = 0)%R. +Proof. + intros H. + destruct n; [easy|]. + simpl; now field_simplify. +Qed. + Lemma Rpow_le1: forall (x : R) (n : nat), 0 <= x <= 1 -> x ^ n <= 1. Proof. intros; induction n. @@ -64,7 +128,7 @@ Proof. apply Rmult_le_compat; try lra. apply pow_le; lra. Qed. - + (* The other side of Rle_pow, needed below *) Lemma Rle_pow_le1: forall (x : R) (m n : nat), 0 <= x <= 1 -> (m <= n)%nat -> x ^ n <= x ^ m. @@ -153,6 +217,26 @@ Proof. intros. assert (H' := H). unfold sqrt in H. destruct (Rcase_abs x). - rewrite <- (sqrt_def x); try rewrite <- H'; lra. Qed. +Lemma sqrt_eq_iff_eq_sqr (r s : R) (Hs : 0 < s) : + sqrt r = s <-> r = pow s 2. +Proof. + split. + - destruct (Rcase_abs r) as [Hr | Hr]; + [rewrite sqrt_neg_0; lra|]. + intros H. + apply (f_equal (fun i => pow i 2)) in H. + rewrite pow2_sqrt in H; lra. + - intros ->. + rewrite sqrt_pow2; lra. +Qed. + +Lemma sqrt_eq_1_iff_eq_1 (r : R) : + sqrt r = 1 <-> r = 1. +Proof. + rewrite sqrt_eq_iff_eq_sqr by lra. + now rewrite pow1. +Qed. + Lemma lt_ep_helper : forall (ϵ : R), ϵ > 0 <-> ϵ / √ 2 > 0. Proof. intros; split; intros. @@ -166,7 +250,36 @@ Proof. intros; split; intros. apply sqrt2_neq_0. Qed. +Lemma pow2_mono_le_inv a b : 0 <= b -> a^2 <= b^2 -> a <= b. +Proof. + intros Hb Hab. + destruct (Rlt_le_dec a 0); [lra|]. + simpl in *. + rewrite 2!Rmult_1_r in *. + destruct (Rlt_le_dec b a); [|easy]. + pose proof (Rmult_lt_compat_l a b a). + pose proof (Rmult_le_compat b a b b). + lra. +Qed. + +Lemma sqrt_ge a b : + a^2 <= b -> a <= √ b. +Proof. + intros Hab. + pose proof (pow2_ge_0 a). + apply pow2_mono_le_inv. + - apply sqrt_pos. + - rewrite pow2_sqrt by lra. + easy. +Qed. +Lemma sqrt_ge_abs a b : + a^2 <= b -> Rabs a <= √ b. +Proof. + intros Hab. + apply sqrt_ge. + now rewrite pow2_abs. +Qed. (** Defining 2-adic valuation of an integer and properties *) @@ -690,3 +803,516 @@ Proof. induction n as [| n']. apply IHn'; intros; apply H; lia. apply H; lia. Qed. + +Lemma Rsum_ge_0_on (n : nat) (f : nat -> R) : + (forall k, (k < n)%nat -> 0 <= f k) -> + 0 <= big_sum f n. +Proof. + induction n; [simpl; lra|]. + intros Hf. + specialize (IHn ltac:(intros;apply Hf;lia)). + simpl. + specialize (Hf n ltac:(lia)). + lra. +Qed. + +Lemma Rsum_nonneg_le (n : nat) (f : nat -> R) a : + (forall k, (k < n)%nat -> 0 <= f k) -> + big_sum f n <= a -> + forall k, (k < n)%nat -> f k <= a. +Proof. + intros Hfge0 Hle. + induction n; [easy|]. + specialize (IHn ltac:(intros; apply Hfge0; lia) + ltac:(simpl in Hle; specialize (Hfge0 n ltac:(lia)); lra)). + simpl in Hle. + pose proof (Rsum_ge_0_on n f ltac:(intros; apply Hfge0; lia)). + intros k Hk. + bdestruct (k =? n). + - subst; lra. + - apply IHn; lia. +Qed. + +Lemma Rsum_nonneg_ge_any n (f : nat -> R) k (Hk : (k < n)%nat) : + (forall i, (i < n)%nat -> 0 <= f i) -> + f k <= big_sum f n. +Proof. + intros Hle. + induction n; [easy|]. + bdestruct (k =? n). + - subst. + simpl. + pose proof (Rsum_ge_0_on n f ltac:(intros;apply Hle;lia)). + lra. + - pose proof (Hle n ltac:(lia)). + simpl. + apply Rle_trans with (big_sum f n). + + apply IHn; [lia | intros; apply Hle; lia]. + + lra. +Qed. + +Lemma pos_IZR (p : positive) : IZR (Z.pos p) = INR (Pos.to_nat p). +Proof. + induction p. + - rewrite IZR_POS_xI. + rewrite Pos2Nat.inj_xI. + rewrite IHp. + rewrite S_INR, mult_INR. + simpl. + lra. + - rewrite IZR_POS_xO. + rewrite Pos2Nat.inj_xO. + rewrite IHp. + rewrite mult_INR. + simpl. + lra. + - reflexivity. +Qed. + +Lemma INR_to_nat (z : Z) : (0 <= z)%Z -> + INR (Z.to_nat z) = IZR z. +Proof. + intros Hz. + destruct z; [reflexivity| | ]. + - simpl. + rewrite pos_IZR. + reflexivity. + - lia. +Qed. + +(* For compatibility pre-8.18 FIXME: Remove when we deprecate those versions *) +Lemma Int_part_spec_compat : forall r z, r - 1 < IZR z <= r -> z = Int_part r. +Proof. + unfold Int_part; intros r z [Hle Hlt]; apply Z.add_move_r, tech_up. + - rewrite <-(Rplus_0_r r), <-(Rplus_opp_l 1), <-Rplus_assoc, plus_IZR. + now apply Rplus_lt_compat_r. + - now rewrite plus_IZR; apply Rplus_le_compat_r. +Qed. +Lemma Rplus_Int_part_frac_part_compat : forall r, r = IZR (Int_part r) + frac_part r. +Proof. now unfold frac_part; intros r; rewrite Rplus_minus. Qed. + +Notation Rplus_Int_part_frac_part := Rplus_Int_part_frac_part_compat. +Notation Int_part_spec := Int_part_spec_compat. + + +Lemma lt_S_Int_part r : r < IZR (1 + Int_part r). +Proof. + rewrite (Rplus_Int_part_frac_part r) at 1. + rewrite Z.add_comm, plus_IZR. + pose proof (base_fp r). + lra. +Qed. + +Lemma lt_Int_part (r s : R) : (Int_part r < Int_part s)%Z -> r < s. +Proof. + intros Hlt. + apply Rlt_le_trans with (IZR (Int_part s)); + [apply Rlt_le_trans with (IZR (1 + Int_part r))|]. + - apply lt_S_Int_part. + - apply IZR_le. + lia. + - apply base_Int_part. +Qed. + +Lemma Int_part_le (r s : R) : r <= s -> (Int_part r <= Int_part s)%Z. +Proof. + intros Hle. + rewrite <- Z.nlt_ge. + intros H%lt_Int_part. + lra. +Qed. + +Lemma IZR_le_iff (z y : Z) : IZR z <= IZR y <-> (z <= y)%Z. +Proof. + split. + - apply le_IZR. + - apply IZR_le. +Qed. + +Lemma IZR_lt_iff (z y : Z) : IZR z < IZR y <-> (z < y)%Z. +Proof. + split. + - apply lt_IZR. + - apply IZR_lt. +Qed. + +Lemma Int_part_IZR (z : Z) : Int_part (IZR z) = z. +Proof. + symmetry. + apply Int_part_spec. + change 1 with (INR (Pos.to_nat 1)). + rewrite <- pos_IZR, <- minus_IZR. + rewrite IZR_le_iff, IZR_lt_iff. + lia. +Qed. + +Lemma Int_part_ge_iff (r : R) (z : Z) : + (z <= Int_part r)%Z <-> (IZR z <= r). +Proof. + split. + - intros Hle. + apply Rle_trans with (IZR (Int_part r)). + + apply IZR_le, Hle. + + apply base_Int_part. + - intros Hle. + rewrite <- (Int_part_IZR z). + apply Int_part_le, Hle. +Qed. + +Lemma Rpower_pos (b c : R) : + 0 < Rpower b c. +Proof. + apply exp_pos. +Qed. + +Lemma ln_nondecreasing x y : 0 < x -> + x <= y -> ln x <= ln y. +Proof. + intros Hx [Hlt | ->]; [|right; reflexivity]. + left. + apply ln_increasing; auto. +Qed. + +Lemma ln_le_inv x y : 0 < x -> 0 < y -> + ln x <= ln y -> x <= y. +Proof. + intros Hx Hy [Hlt | Heq]; + [left; apply ln_lt_inv; auto|]. + right. + apply ln_inv; auto. +Qed. + + +Lemma Rdiv_le_iff a b c (Hb : 0 < b) : + a / b <= c <-> a <= b * c. +Proof. + split. + - intros Hle. + replace a with (b * (a / b)). + 2: { + unfold Rdiv. + rewrite <- Rmult_assoc, (Rmult_comm b a), Rmult_assoc. + rewrite Rinv_r; lra. + } + apply Rmult_le_compat_l; lra. + - intros Hle. + apply Rle_trans with (b * c / b). + + apply Rmult_le_compat_r. + * enough (0 < / b) by lra. + now apply Rinv_0_lt_compat. + * easy. + + rewrite Rmult_comm. + unfold Rdiv. + rewrite Rmult_assoc, Rinv_r; lra. +Qed. + +Lemma div_Rpower_le_of_le (r b c d : R) : + 0 < r -> 1 < b -> 0 < c -> 0 < d -> + ln (r / d) / ln b <= c -> + r / (Rpower b c) <= d. +Proof. + intros Hr Hb Hc Hd Hle. + assert (0 < Rpower b c) by apply Rpower_pos. + rewrite Rdiv_le_iff, Rmult_comm, + <- Rdiv_le_iff by auto. + apply ln_le_inv; + [apply Rdiv_lt_0_compat; lra|auto|]. + unfold Rpower. + rewrite ln_exp. + rewrite Rmult_comm. + rewrite <- Rdiv_le_iff; [auto|]. + rewrite <- ln_1. + apply ln_increasing; lra. +Qed. + + +Section Rnthroot_def. + +Lemma Rpow_1_l n : 1 ^ n = 1. +Proof. + induction n; [easy|]. + simpl; lra. +Qed. + +Lemma Rpow_ge_1_ge n x : 1 <= x -> + x <= x ^ S n. +Proof. + intros Hx. + induction n. + - lra. + - apply Rle_trans with (x ^ S n); [assumption|]. + simpl. + rewrite <- (Rmult_1_l) at 1. + apply Rmult_le_compat; try lra. + apply Rle_trans with x; [lra|apply IHn]. +Qed. + +Lemma Rpow_eq_0_iff n x : x ^ n = 0 <-> (x = 0 /\ n <> O). +Proof. + split. + - intros H. + destruct n; [simpl in *; lra|]. + split; [|easy]. + induction n; [lra|]. + simpl in H. + apply Rmult_integral in H. + destruct H; [|apply IHn]; apply H. + - intros [-> Hn]. + apply Rpow_0_l; lia. +Qed. + +Lemma derivable_Rpow n : derivable (fun x => x ^ n). +Proof. + induction n. + - apply derivable_const. + - change (fun x => x ^ S n) with (id * (fun x => x ^ n))%F. + apply derivable_mult; [|assumption]. + apply derivable_id. +Qed. + +Lemma Rnthroot_exists n : + forall y:R, 0 <= y -> sigT (fun z:R => 0 <= z /\ y = z ^ S n). +Proof. + intros y Hy. + set (f := fun x:R => x ^ S n - y). + assert (Hf0 : f 0 <= 0) + by (unfold f; rewrite Rpow_0_l by easy; lra). + assert (H : continuity f). 1:{ + change f with ((fun x => x ^ S n) - fct_cte y)%F. + apply continuity_minus. + apply derivable_continuous; apply derivable_Rpow. + apply derivable_continuous; apply derivable_const. + } + case (total_order_T y 1); intro. + elim s; intro. + - assert (H0 : 0 <= f 1) by + (unfold f; rewrite Rpow_1_l; lra). + assert (H1 : f 0 * f 1 <= 0). 1: { + rewrite Rmult_comm. + pattern 0 at 2 in |- *. + rewrite <- (Rmult_0_r (f 1)). + apply Rmult_le_compat_l; assumption. + } + assert (X := IVT_cor f 0 1 H (Rlt_le _ _ Rlt_0_1) H1). + elim X; intros t H4. + apply existT with t. + elim H4; intros. + split. + lra. + unfold f in H3. + apply Rminus_diag_uniq_sym; exact H3. + - apply existT with 1. + split. + left; apply Rlt_0_1. + rewrite b; symmetry in |- *; apply Rpow_1_l. + - + assert (H0 : 0 <= f y). 1: { + unfold f. + apply Rplus_le_reg_l with y. + rewrite Rplus_0_r; rewrite Rplus_comm; unfold Rminus in |- *; + rewrite Rplus_assoc; rewrite Rplus_opp_l; rewrite Rplus_0_r. + apply Rpow_ge_1_ge; lra. + } + assert (H1 : f 0 * f y <= 0). 1: { + rewrite Rmult_comm. + pattern 0 at 2 in |- *. + rewrite <- (Rmult_0_r (f y)). + apply Rmult_le_compat_l; assumption. + } + assert (X := IVT_cor f 0 y H Hy H1). + elim X; intros t H4. + apply existT with t. + elim H4; intros. + split. + lra. + unfold f in H3. + apply Rminus_diag_uniq_sym; exact H3. +Qed. + +Definition Rnthroot n x := + match Rcase_abs x with + | left _ => 0 + | right a => + match Rnthroot_exists n x (Rge_le _ _ a) with + | existT _ a b => a + end + end. + +Lemma Rnthroot_positivity n x : 0 <= Rnthroot n x. +Proof. + unfold Rnthroot. + destruct (Rcase_abs x); [lra|]. + destruct (Rnthroot_exists n x (Rge_le x 0 r)). + easy. +Qed. + +Lemma Rpow_Sn_nthroot n x : 0 <= x -> + (Rnthroot n x) ^ S n = x. +Proof. + intros Hx. + unfold Rnthroot. + destruct (Rcase_abs x); [lra|]. + destruct (Rnthroot_exists n x (Rge_le x 0 r)). + easy. +Qed. + +Lemma Rnthroot_0_l x : 0 <= x -> Rnthroot 0 x = x. +Proof. + pose proof (Rpow_Sn_nthroot 0 x). + lra. +Qed. + +Lemma Rnthroot_0_r n : Rnthroot n 0 = 0. +Proof. + pose proof (Rpow_Sn_nthroot n 0 ltac:(lra)). + now rewrite Rpow_eq_0_iff in H. +Qed. + +Lemma Rnthroot_def_alt n x : + Rnthroot n x = + match Rlt_le_dec 0 x with + | left H0x => Rnthroot n x + | right Hxnonpos => 0 + end. +Proof. + destruct (Req_dec x 0). + - subst. + rewrite Rnthroot_0_r. + now destruct (Rlt_le_dec 0 0). + - unfold Rnthroot. + destruct (Rcase_abs x), (Rlt_le_dec 0 x); try lra. + now assert (x = 0) by lra. +Qed. + +Lemma Rnthroot_nonpos n x : x <= 0 -> Rnthroot n x = 0. +Proof. + intros Hx. + destruct (Req_dec x 0). + - subst. + apply Rnthroot_0_r. + - unfold Rnthroot. + destruct (Rcase_abs x); [lra|]. + now assert (x = 0) by lra. +Qed. + +Lemma Rpow_pos n x : 0 <= x -> 0 <= x ^ n. +Proof. + intros Hx. + induction n; [lra|]. + simpl. + replace 0 with (0 * 0) by lra. + apply Rmult_le_compat; lra. +Qed. + +Lemma Rpow_le_compat_r n x y : 0 <= x -> + x <= y -> x ^ n <= y ^ n. +Proof. + intros Hx Hxy. + induction n; [lra|]. + simpl. + pose proof (Rpow_pos n x Hx). + apply Rmult_le_compat; lra. +Qed. + +Lemma Rpow_lt_compat_r n x y : n <> O -> 0 <= x -> + x < y -> x ^ n < y ^ n. +Proof. + intros Hn Hx Hxy. + induction n; [easy|]. + destruct n; [lra|]. + simpl in *. + specialize (IHn ltac:(easy)). + apply Rle_lt_trans with (y * (x * x ^ n)). + - pose proof (Rpow_pos (S n) x Hx); simpl in *. + apply Rmult_le_compat_r; lra. + - apply Rmult_lt_compat_l; lra. +Qed. + +Lemma Rpow_inj_pos n x y : 0 <= x -> 0 <= y -> n <> O -> + x ^ n = y ^ n -> x = y. +Proof. + intros Hx Hy Hn Hpow. + destruct (Rle_lt_dec x y), (Rle_lt_dec y x); try lra. + - pose proof (Rpow_lt_compat_r n x y Hn Hx); lra. + - pose proof (Rpow_lt_compat_r n y x Hn Hy); lra. +Qed. + +Lemma Rnthroot_pow_Sn n x : 0 <= x -> + Rnthroot n (x ^ S n) = x. +Proof. + intros Hx. + apply Rpow_inj_pos with (S n). + - apply Rnthroot_positivity. + - easy. + - easy. + - now rewrite Rpow_Sn_nthroot by (now apply Rpow_pos). +Qed. + + + + +Lemma Rnthroot_mult_distr n x y : (0 <= x \/ 0 <= y) -> + Rnthroot n (x * y) = Rnthroot n x * Rnthroot n y. +Proof. + intros Hor. + (* unfold Rnthroot. *) + destruct (Rcase_abs x), (Rcase_abs y); [lra|..]. + - rewrite Rnthroot_nonpos. + 2: { + replace 0 with (0 * y) by lra. + apply Rmult_le_compat_r; lra. + } + rewrite Rnthroot_nonpos; lra. + - rewrite Rnthroot_nonpos. + 2: { + replace 0 with (x * 0) by lra. + apply Rmult_le_compat_l; lra. + } + rewrite (Rnthroot_nonpos n y); lra. + - pose proof (Rpow_Sn_nthroot n (x * y) + ltac:(replace 0 with (x*0); + [apply Rmult_le_compat_l|];lra)). + rewrite <- (Rpow_Sn_nthroot n x), <- (Rpow_Sn_nthroot n y) + in H at 2 by lra. + rewrite <- Rpow_mult_distr in H. + apply Rpow_inj_pos in H. + + easy. + + apply Rnthroot_positivity. + + replace 0 with (0*0); [apply Rmult_le_compat|]; + try apply Rnthroot_positivity; try lra. + + easy. +Qed. + +Lemma Rnthroot_1_r n : Rnthroot n 1 = 1. +Proof. + pose proof (Rpow_Sn_nthroot n 1 ltac:(lra)). + rewrite <- (Rpow_1_l (S n)) in H at 2. + apply Rpow_inj_pos in H; + easy + apply Rnthroot_positivity + lra. +Qed. + +Lemma Rnthroot_pow n m x : 0 <= x -> + Rnthroot n (x ^ m) = Rnthroot n x ^ m. +Proof. + intros Hx. + induction m. + - now rewrite Rnthroot_1_r. + - simpl. + rewrite Rnthroot_mult_distr by auto. + now rewrite IHm. +Qed. + +Lemma Rnthroot_nthroot n m x : + Rnthroot n (Rnthroot m x) = Rnthroot (n * m + n + m) x. +Proof. + destruct (Rlt_le_dec x 0); + [now rewrite !(Rnthroot_nonpos _ x ltac:(lra)), Rnthroot_0_r|]. + apply (Rpow_inj_pos (S n * S m)); [apply Rnthroot_positivity..|easy|]. + replace (S n * S m)%nat with (S (n * m + n + m)) at 2 by lia. + rewrite Rpow_Sn_nthroot by easy. + rewrite pow_mult. + rewrite 2!Rpow_Sn_nthroot by (easy + apply Rnthroot_positivity). + easy. +Qed. + +End Rnthroot_def. diff --git a/Rings.v b/Rings.v index 32465d6..46d37b7 100644 --- a/Rings.v +++ b/Rings.v @@ -60,6 +60,13 @@ Proof. intros. - lca. Qed. +Lemma Cpow_pos_0 : forall (p : positive), Cpow_pos C0 p = C0. +Proof. + intros p. + rewrite Cpow_pos_to_nat. + apply Cpow_0_l. + lia. +Qed. Lemma Cpow_pos_nonzero : forall (c : C) (p : positive), c <> C0 -> Cpow_pos c p <> C0. Proof. intros. @@ -103,30 +110,30 @@ Proof. intros. Qed. Lemma Cpow_pos_pred_double : forall (c : C) (p : positive), - c <> 0 -> Cpow_pos c (Pos.pred_double p) = (/ c) * (Cpow_pos c p * Cpow_pos c p). -Proof. intros. - induction p; simpl; auto. - - repeat rewrite Cmult_assoc. - rewrite Cinv_l; try lca; auto. - - rewrite IHp. - repeat rewrite Cmult_assoc. - rewrite Cinv_r; try lca; auto. - - rewrite Cmult_assoc, Cinv_l; try lca; auto. +Proof. + intros c p. + destruct (Ceq_dec c 0); [subst; rewrite Cpow_pos_0; lca|]. + induction p; simpl; auto. + - repeat rewrite Cmult_assoc. + rewrite Cinv_l; try lca; auto. + - rewrite IHp. + repeat rewrite Cmult_assoc. + rewrite Cinv_r; try lca; auto. + - rewrite Cmult_assoc, Cinv_l; try lca; auto. Qed. Lemma Cpow_pos_inv : forall (c : C) (p : positive), - c <> 0 -> Cpow_pos (/ c) p = / (Cpow_pos c p). -Proof. intros. - induction p; simpl. - - rewrite IHp. - repeat rewrite Cinv_mult_distr; auto. - 2 : apply Cmult_neq_0. - all : try apply Cpow_pos_nonzero; auto. - - rewrite IHp. - rewrite Cinv_mult_distr; try apply Cpow_pos_nonzero; auto. - - easy. +Proof. + intros c p. + destruct (Ceq_dec c 0); [subst; rewrite Cinv_0, Cpow_pos_0; lca|]. + induction p; simpl. + - rewrite IHp. + repeat rewrite Cinv_mult_distr; auto. + - rewrite IHp. + rewrite Cinv_mult_distr; try apply Cpow_pos_nonzero; auto. + - easy. Qed. Lemma Cpow_pos_real : forall (c : C) (p : positive), @@ -171,21 +178,30 @@ Proof. intros. apply Cpow_pos_nonzero; easy. Qed. +Lemma Cpow_int_0_l : forall (z : Z), z <> 0%Z -> C0 ^^ z = C0. +Proof. + intros z Hz. + destruct z; [easy|..]; + simpl. + - apply Cpow_pos_0. + - rewrite Cpow_pos_0, Cinv_0. + reflexivity. +Qed. Lemma Cpow_int_add_1 : forall (c : C) (z : Z), c <> C0 -> c ^^ (1 + z) = c * c^^z. -Proof. intros. - destruct z; try lca. - - destruct p; simpl; try lca. - rewrite Cpow_pos_succ; lca. - - destruct p; simpl. - + rewrite <- Cmult_assoc, (Cinv_mult_distr c); auto. - rewrite Cmult_assoc, Cinv_r; try lca; auto. - apply Cmult_neq_0; apply Cpow_pos_nonzero; easy. - + rewrite Cpow_pos_pred_double, Cinv_mult_distr, Cinv_inv; auto. - apply nonzero_div_nonzero; auto. - apply Cmult_neq_0; apply Cpow_pos_nonzero; easy. - + rewrite Cinv_r; easy. +Proof. + intros. + destruct z; try lca. + - destruct p; simpl; try lca. + rewrite Cpow_pos_succ; lca. + - destruct p; simpl. + + rewrite <- Cmult_assoc, (Cinv_mult_distr c); auto. + destruct (Ceq_dec c 0); + [subst; now rewrite Cpow_pos_0, !Cmult_0_l, Cinv_0|]. + rewrite Cmult_assoc, Cinv_r; try lca; auto. + + rewrite Cpow_pos_pred_double, Cinv_mult_distr, Cinv_inv; auto. + + rewrite Cinv_r; easy. Qed. Lemma Cpow_int_minus_1 : forall (c : C) (z : Z), @@ -200,9 +216,7 @@ Proof. intros. - destruct p; simpl; try lca. + rewrite Cpow_pos_succ, <- Cinv_mult_distr; auto. apply f_equal; lca. - repeat apply Cmult_neq_0; try apply Cpow_pos_nonzero; auto. + rewrite <- Cmult_assoc, Cinv_mult_distr; auto. - apply Cmult_neq_0; apply Cpow_pos_nonzero; auto. + rewrite Cinv_mult_distr; easy. Qed. @@ -242,7 +256,6 @@ Proof. intros. all : rewrite Cpow_pos_mult_r; try lca. all : rewrite Cpow_pos_inv; try apply Cpow_pos_nonzero; auto. rewrite Cinv_inv; auto. - do 2 apply Cpow_pos_nonzero; easy. Qed. Lemma Cpow_int_mult_l : forall (c1 c2 : C) (z : Z), diff --git a/RowColOps.v b/RowColOps.v index 2127df7..bc87f9c 100644 --- a/RowColOps.v +++ b/RowColOps.v @@ -2529,6 +2529,32 @@ Proof. intros. rewrite <- matrix_by_basis_adjoint, <- matrix_by_basis; auto. Qed. +Lemma mat_equiv_of_equiv_on_ei : forall {n m} (A B : Matrix n m), + (forall k, (k < m)%nat -> A × e_i k ≡ B × e_i k) -> + A ≡ B. +Proof. + intros n m A B Heq. + intros i j Hi Hj. + specialize (Heq j Hj). + rewrite <- 2!(matrix_by_basis _ _ Hj) in Heq. + specialize (Heq i O Hi ltac:(lia)). + unfold get_col in Heq. + rewrite Nat.eqb_refl in Heq. + easy. +Qed. + +Lemma eq_of_eq_on_ei : forall {n m} (A B : Matrix n m), + WF_Matrix A -> WF_Matrix B -> + (forall k, (k < m)%nat -> A × e_i k = B × e_i k) -> + A = B. +Proof. + intros n m A B HA HB HAB. + apply mat_equiv_eq; [easy..|]. + apply mat_equiv_of_equiv_on_ei. + intros k Hk. + now rewrite HAB by easy. +Qed. + (** * Lemmas related to 1pad *) diff --git a/Summation.v b/Summation.v index bc6d63c..499774f 100644 --- a/Summation.v +++ b/Summation.v @@ -151,6 +151,13 @@ Proof. intros. easy. Qed. +Lemma Gopp_0 : forall {R} `{Group R}, - 0 = 0. +Proof. + symmetry. + apply Gopp_unique_r. + apply Gplus_0_l. +Qed. + Lemma Gopp_involutive : forall {G} `{Group G} (g : G), - (- g) = g. Proof. intros. @@ -188,6 +195,13 @@ Proof. intros. rewrite Gplus_0_r, <- Gmult_plus_distr_l, Gplus_0_r; easy. Qed. +Lemma Gmult_if : forall {R} `{Ring R} (b c : bool) (r s : R), + (if b then r else 0) * (if c then s else 0) = + if b && c then r * s else 0. +Proof. + intros. + destruct b, c; now rewrite ?Gmult_0_l, ?Gmult_0_r. +Qed. Lemma Gopp_neg_1 : forall {R} `{Ring R} (r : R), -1%G * r = -r. Proof. intros. @@ -197,7 +211,6 @@ Proof. intros. rewrite <- Gmult_plus_distr_r, Gopp_r, Gmult_0_l; easy. Qed. - Lemma Ginv_l : forall {F} `{Field F} (f : F), f <> 0 -> (Ginv f) * f = 1. Proof. intros; rewrite Gmult_comm; apply Ginv_r; easy. Qed. @@ -372,6 +385,14 @@ Proof. induction n; try easy. lia. Qed. +Lemma big_sum_opp : forall {G} `{Comm_Group G} (f : nat -> G) n, + - big_sum f n = big_sum (fun k => - f k) n. +Proof. + induction n; simpl. + - apply Gopp_0. + - rewrite Gopp_plus_distr. + now rewrite Gplus_comm, IHn. +Qed. Lemma big_sum_plus : forall {G} `{Comm_Group G} f g n, big_sum (fun x => f x + g x) n = big_sum f n + big_sum g n. @@ -388,6 +409,61 @@ Proof. rewrite Gplus_comm; easy. Qed. +Lemma big_sum_if_or : forall {G} `{Comm_Group G} + (ifl ifr : nat -> bool) + (f : nat -> G) (n : nat), + big_sum (fun k => if ifl k || ifr k then f k else 0) n = + big_sum (fun k => if ifl k then f k else 0) n + + big_sum (fun k => if ifr k then f k else 0) n - + big_sum (fun k => if ifl k && ifr k then f k else 0) n. +Proof. + intros. + unfold Gminus. + rewrite big_sum_opp. + rewrite <- 2!big_sum_plus. + apply big_sum_eq_bounded. + intros k Hk. + destruct (ifl k), (ifr k); simpl; + rewrite <- Gplus_assoc, ?Gopp_r, + ?Gopp_0, ?Gplus_0_r, ?Gplus_0_l; easy. +Qed. + +Lemma big_sum_if_eq : forall {G} `{Monoid G} (f : nat -> G) n k, + big_sum (fun x => if x =? k then f x else 0) n = + if k G) n k, + big_sum (fun x => if k =? x then f x else 0) n = + if k G) n k l, k <> l -> + big_sum (fun x => if (x =? k) || (x =? l) then f x else 0) n = + (if k c ⋅ f x) n. Proof. @@ -460,6 +536,27 @@ Proof. + rewrite Gplus_assoc, IHn. simpl; reflexivity. Qed. + +Lemma big_sum_split : forall {G} `{Monoid G} n i (v : nat -> G) (Hi : (i < n)), + big_sum v n = (big_sum v i) + v i + + (big_sum (fun k => v (k + i + 1)%nat) (n - 1 - i)). +Proof. + intros. + induction n; [lia|]. + bdestruct (i =? n). + - subst. + replace (S n - 1 - n)%nat with O by lia. + rewrite <- big_sum_extend_r. + simpl. + symmetry. + apply Gplus_0_r. + - specialize (IHn ltac:(lia)). + replace (S n - 1 - i)%nat with (S (n - 1 - i))%nat by lia. + rewrite <- !big_sum_extend_r. + rewrite IHn. + replace (n - 1 - i + i + 1)%nat with n by lia. + now rewrite Gplus_assoc. +Qed. Lemma big_sum_unique : forall {G} `{Monoid G} k (f : nat -> G) n, (exists x, (x < n)%nat /\ f x = k /\ (forall x', x' < n -> x <> x' -> f x' = 0)) -> @@ -587,6 +684,19 @@ Proof. induction m as [| m']. easy. Qed. +Lemma big_sum_product_div_mod_split : forall {G} `{Monoid G} n m (f : nat -> G), + big_sum f (n * m) = + big_sum (fun i => big_sum (fun j => f (j + i * n)%nat) n) m. +Proof. + intros. + rewrite big_sum_double_sum. + apply big_sum_eq_bounded. + intros k Hk. + f_equal. + rewrite (Nat.div_mod_eq k n) at 1. + lia. +Qed. + Local Close Scope nat_scope. Lemma big_sum_extend_double : forall {G} `{Ring G} (f : nat -> nat -> G) (n m : nat), diff --git a/VectorStates.v b/VectorStates.v index c1f15ff..6f4f871 100644 --- a/VectorStates.v +++ b/VectorStates.v @@ -1,10 +1,11 @@ Require Export Pad. Require Export CauchySchwarz. -Require Import Bits. +Require Import PermutationInstances. +Require Export Bits. (* This file provides abstractions for describing quantum states as vectors. - f_to_vec describes classical states as boolean functions - - basis_vector describes classiacal states as natural numbers + - basis_vector describes classical states as natural numbers - vsum describes superposition states - vkron describes states as the tensor product of qubit states @@ -33,6 +34,24 @@ Proof. reflexivity. Qed. Lemma ket1_equiv : ∣1⟩ = ket 1. Proof. reflexivity. Qed. +Lemma plus_equiv : ∣+⟩ = ∣ + ⟩. +Proof. lma'. Qed. + +Lemma minus_equiv : ∣-⟩ = ∣ - ⟩. +Proof. lma'. Qed. + +Lemma bra0_eqb : ⟨0∣ = (fun i j => if (i =? 0) && (j =? 0) then C1 else C0). +Proof. lma'. intros i j []; Modulus.bdestructΩ'. Qed. + +Lemma bra1_eqb : ⟨1∣ = (fun i j => if (i =? 0) && (j =? 1) then C1 else C0). +Proof. lma'. intros i j []; Modulus.bdestructΩ'. Qed. + +Lemma ket0_eqb : ∣0⟩ = (fun i j => if (i =? 0) && (j =? 0) then C1 else C0). +Proof. lma'. intros i j []; Modulus.bdestructΩ'. Qed. + +Lemma ket1_eqb : ∣1⟩ = (fun i j => if (i =? 1) && (j =? 0) then C1 else C0). +Proof. lma'. intros i j []; Modulus.bdestructΩ'. Qed. + Lemma bra0ket0 : bra 0 × ket 0 = I 1. Proof. lma'. Qed. @@ -45,6 +64,22 @@ Proof. lma'. Qed. Lemma bra1ket1 : bra 1 × ket 1 = I 1. Proof. lma'. Qed. +Lemma bra0ket_eqb i : bra 0 × ket i = + if i =? 0 then I 1 else Zero. +Proof. + destruct i; simpl. + - apply bra0ket0. + - apply bra0ket1. +Qed. + +Lemma bra1ket_eqb i : bra 1 × ket i = + if i =? 0 then Zero else I 1. +Proof. + destruct i; simpl. + - apply bra1ket0. + - apply bra1ket1. +Qed. + (* Hadamard properties *) Lemma H0_spec : hadamard × ∣ 0 ⟩ = ∣ + ⟩. Proof. lma'. Qed. @@ -53,10 +88,12 @@ Lemma H1_spec : hadamard × ∣ 1 ⟩ = ∣ - ⟩. Proof. lma'. Qed. Lemma Hplus_spec : hadamard × ∣ + ⟩ = ∣ 0 ⟩. -Proof. solve_matrix. Qed. +Proof. solve_matrix_fast_with + (autounfold with U_db) (try lca; C_field; lca). Qed. Lemma Hminus_spec : hadamard × ∣ - ⟩ = ∣ 1 ⟩. -Proof. solve_matrix. Qed. +Proof. solve_matrix_fast_with + (autounfold with U_db) (try lca; C_field; lca). Qed. Local Open Scope nat_scope. @@ -65,20 +102,18 @@ Local Open Scope nat_scope. Lemma H0_kron_n_spec : forall n, n ⨂ hadamard × n ⨂ ∣0⟩ = n ⨂ ∣+⟩. Proof. - intros. - induction n; simpl. - - Msimpl_light. reflexivity. - - replace (2^n + (2^n + 0)) with (2^n * 2) by lia. - replace (1^n + 0) with (1*1) by (rewrite Nat.pow_1_l, Nat.add_0_r; lia). - rewrite Nat.pow_1_l. - rewrite kron_mixed_product. - rewrite <- IHn. - apply f_equal_gen; try reflexivity. - lma'. + intros n. + rewrite kron_n_mult. + rewrite ket0_equiv, plus_equiv. + now rewrite H0_spec. Qed. Local Close Scope nat_scope. +Definition b2R (b : bool) : R := if b then 1%R else 0%R. +Local Coercion b2R : bool >-> R. +Local Coercion Nat.b2n : bool >-> nat. + (* X properties *) Lemma X0_spec : σx × ∣ 0 ⟩ = ∣ 1 ⟩. Proof. lma'. Qed. @@ -86,6 +121,13 @@ Proof. lma'. Qed. Lemma X1_spec : σx × ∣ 1 ⟩ = ∣ 0 ⟩. Proof. lma'. Qed. +Lemma X_specb (b : bool) : σx × ∣ b ⟩ = ∣ negb b ⟩. +Proof. + destruct b. + - apply X1_spec. + - apply X0_spec. +Qed. + (* Y properties *) Lemma Y0_spec : σy × ∣ 0 ⟩ = Ci .* ∣ 1 ⟩. Proof. lma'. Qed. @@ -93,6 +135,16 @@ Proof. lma'. Qed. Lemma Y1_spec : σy × ∣ 1 ⟩ = -Ci .* ∣ 0 ⟩. Proof. lma'. Qed. +Lemma Y_specb (b : bool) : + σy × ∣ b ⟩ = (-1)^b * Ci .* ∣ negb b ⟩. +Proof. + destruct b. + - simpl. rewrite Y1_spec. + f_equal; lca. + - simpl. rewrite Y0_spec. + f_equal; lca. +Qed. + (* Z properties *) Lemma Z0_spec : σz × ∣ 0 ⟩ = ∣ 0 ⟩. Proof. lma'. Qed. @@ -100,6 +152,36 @@ Proof. lma'. Qed. Lemma Z1_spec : σz × ∣ 1 ⟩ = -1 .* ∣ 1 ⟩. Proof. lma'. Qed. +Lemma Z_specb (b : bool) : + σz × ∣ b ⟩ = (-1)^b .* ∣ b ⟩. +Proof. + destruct b. + - simpl. rewrite Z1_spec. + now Csimpl. + - simpl. rewrite Z0_spec. + now rewrite Mscale_1_l. +Qed. + +Lemma Z_bspec (b : bool) : + bra b × σz = (-1)^b .* bra b. +Proof. + destruct b. + - simpl. lma'. + - simpl. lma'. +Qed. + +Lemma MmultZ1 : σz × ∣1⟩ = - C1 .* ∣1⟩. +Proof. rewrite ket1_equiv, Z1_spec. f_equal; lca. Qed. + +Lemma MmultZ0 : σz × ∣0⟩ = ∣0⟩. +Proof. rewrite ket0_equiv, Z0_spec. reflexivity. Qed. + +Lemma Mmult1Z : ⟨1∣ × σz = - C1 .* ⟨1∣. +Proof. lma'. Qed. + +Lemma Mmult0Z : ⟨0∣ × σz = ⟨0∣. +Proof. lma'. Qed. + (* phase shift properties *) Lemma phase0_spec : forall ϕ, phase_shift ϕ × ket 0 = ket 0. Proof. intros. lma'. Qed. @@ -107,33 +189,24 @@ Proof. intros. lma'. Qed. Lemma phase1_spec : forall ϕ, phase_shift ϕ × ket 1 = Cexp ϕ .* ket 1. Proof. intros. lma'. Qed. -Definition b2R (b : bool) : R := if b then 1%R else 0%R. -Local Coercion b2R : bool >-> R. -Local Coercion Nat.b2n : bool >-> nat. - Lemma phase_shift_on_ket : forall (θ : R) (b : bool), phase_shift θ × ∣ b ⟩ = (Cexp (b * θ)) .* ∣ b ⟩. Proof. intros. - destruct b; solve_matrix; autorewrite with R_db. - reflexivity. - rewrite Cexp_0; reflexivity. + destruct b; simpl; + [rewrite Rmult_1_l | rewrite Rmult_0_l, Cexp_0]; + solve_matrix_fast. Qed. Lemma hadamard_on_ket : forall (b : bool), hadamard × ∣ b ⟩ = /√2 .* (∣ 0 ⟩ .+ (-1)^b .* ∣ 1 ⟩). Proof. intros. - destruct b; solve_matrix; autorewrite with R_db Cexp_db; lca. + destruct b; solve_matrix_fast. Qed. (* CNOT properties *) -Lemma CNOT_spec : forall (x y : nat), (x < 2)%nat -> (y < 2)%nat -> cnot × ∣ x,y ⟩ = ∣ x, (x + y) mod 2 ⟩. -Proof. - intros; destruct x as [| [|x]], y as [| [|y]]; try lia; lma'. -Qed. - Lemma CNOT00_spec : cnot × ∣ 0,0 ⟩ = ∣ 0,0 ⟩. Proof. lma'. Qed. @@ -146,10 +219,22 @@ Proof. lma'. Qed. Lemma CNOT11_spec : cnot × ∣ 1,1 ⟩ = ∣ 1,0 ⟩. Proof. lma'. Qed. +Lemma CNOT_spec : forall (x y : nat), (x < 2)%nat -> (y < 2)%nat -> + cnot × ∣ x,y ⟩ = ∣ x, (x + y) mod 2 ⟩. +Proof. + by_cell_no_intros. + - apply CNOT00_spec. + - apply CNOT01_spec. + - apply CNOT10_spec. + - apply CNOT11_spec. +Qed. + + + (* SWAP properties *) Lemma SWAP_spec : forall x y, swap × ∣ x,y ⟩ = ∣ y,x ⟩. -Proof. intros. destruct x,y; lma'. Qed. +Proof. intros. apply swap_spec; auto_wf. Qed. (* Automation *) @@ -174,28 +259,28 @@ Proof. destruct n; reflexivity. Qed. (* TODO: add transpose and adjoint lemmas to ket_db? *) Lemma ket0_transpose_bra0 : (ket 0) ⊤ = bra 0. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma ket1_transpose_bra1 : (ket 1) ⊤ = bra 1. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma bra0_transpose_ket0 : (bra 0) ⊤ = ket 0. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma bra1_transpose_ket1 : (bra 1) ⊤ = ket 1. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma bra1_adjoint_ket1 : (bra 1) † = ket 1. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma ket1_adjoint_bra1 : (ket 1) † = bra 1. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma bra0_adjoint_ket0 : (bra 0) † = ket 0. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. Lemma ket0_adjoint_bra0 : (ket 0) † = bra 0. -Proof. solve_matrix. Qed. +Proof. lma'. Qed. (* Examples using ket_db *) Lemma XYZ0 : -Ci .* σx × σy × σz × ∣ 0 ⟩ = ∣ 0 ⟩. @@ -205,8 +290,8 @@ Lemma XYZ1 : -Ci .* σx × σy × σz × ∣ 1 ⟩ = ∣ 1 ⟩. Proof. autorewrite with ket_db C_db. replace (Ci * -1 * Ci) with (RtoC 1) by lca. - rewrite Mscale_1_l; reflexivity. - Qed. + apply Mscale_1_l. +Qed. (*******************************) @@ -357,23 +442,15 @@ Lemma split_basis_vector : forall m n x y, = basis_vector (2 ^ m) x ⊗ basis_vector (2 ^ n) y. Proof. intros m n x y Hx Hy. + apply mat_equiv_eq; + [apply basis_vector_WF; Modulus.show_moddy_lt|auto_wf|]. unfold kron, basis_vector. - solve_matrix. - bdestruct (y0 =? 0). - - repeat rewrite andb_true_r. - assert (2^n > 0)%nat. - { assert (0 < 2^n)%nat by (apply pow_positive; lia). lia. - } - specialize (divmod_decomp x0 x y (2^n)%nat H0 Hy) as G. - bdestruct (x0 =? x * 2 ^ n + y). - + apply G in H1. destruct H1. - rewrite H1, H2. do 2 rewrite Nat.eqb_refl. lca. - + bdestruct (x0 / 2 ^ n =? x); bdestruct (x0 mod 2 ^ n =? y); try lca. - assert ((x0 / 2 ^ n)%nat = x /\ x0 mod 2 ^ n = y) by easy. - apply G in H4. - easy. - - repeat rewrite andb_false_r. - lca. + intros i j Hi Hj. + replace j with 0 by lia. + Modulus.simpl_bools. + rewrite Cmult_if_if_1_l. + rewrite Modulus.eqb_comb_iff_div_mod_eqb by easy. + now rewrite andb_comm. Qed. (* rewrite f_to_vec as basis_vector *) @@ -383,32 +460,26 @@ Proof. intros. induction n. - unfold funbool_to_nat; simpl. - unfold basis_vector. - unfold I. - prep_matrix_equality. - bdestruct (x =? 0); bdestruct (x =? y); subst; simpl; trivial. - bdestruct_all; easy. - bdestructΩ (y WF_Matrix B -> (forall f, A × (f_to_vec dim f) = B × (f_to_vec dim f)) -> A = B. Proof. - intros dim A B WFA WFB H. + intros n dim A B WFA WFB H. apply equal_on_basis_vectors_implies_equal; trivial. intros k Lk. rewrite basis_f_to_vec_alt; auto. Qed. +Lemma equal_on_conj_basis_states_implies_equal {n m} + (A B : Matrix (2 ^ n) (2 ^ m)) : WF_Matrix A -> WF_Matrix B -> + (forall f g, (f_to_vec n g) ⊤ × (A × f_to_vec m f) = + (f_to_vec n g) ⊤ × (B × f_to_vec m f)) -> A = B. +Proof. + intros HA HB HAB. + apply equal_on_basis_states_implies_equal; [auto..|]. + intros f. + apply transpose_matrices. + apply equal_on_basis_states_implies_equal; [auto_wf..|]. + intros g. + apply transpose_matrices. + rewrite Mmult_transpose, transpose_involutive, HAB. + rewrite Mmult_transpose, transpose_involutive. + reflexivity. +Qed. + Lemma f_to_vec_update_oob : forall (n : nat) (f : nat -> bool) (i : nat) (b : bool), n <= i -> f_to_vec n (update f i b) = f_to_vec n f. Proof. @@ -526,6 +614,43 @@ Proof. destruct (f i); simpl; autorewrite with ket_db; reflexivity. Qed. +Lemma f_to_vec_σy : forall (n i : nat) (f : nat -> bool), + i < n -> + (pad_u n i σy) × (f_to_vec n f) = + (-1)%R^(f i) * Ci .* f_to_vec n (update f i (¬ f i)). +Proof. + intros n i f Hi. + unfold pad_u, pad. + rewrite (f_to_vec_split 0 n i f Hi). + repad. + replace (i + 1 + x - 1 - i) with x by lia. + Msimpl. + rewrite Y_specb. + distribute_scale. + rewrite (f_to_vec_split 0 (i + 1 + x) i) by lia. + rewrite f_to_vec_update_oob by lia. + rewrite f_to_vec_shift_update_oob by lia. + rewrite update_index_eq. + replace (i + 1 + x - 1 - i) with x by lia. + easy. +Qed. + +Lemma f_to_vec_σz : forall (n i : nat) (f : nat -> bool), + i < n -> + (pad_u n i σz) × (f_to_vec n f) = + (-1)%R^(f i) .* f_to_vec n f. +Proof. + intros n i f Hi. + unfold pad_u, pad. + rewrite (f_to_vec_split 0 n i f Hi). + repad. + replace (i + 1 + x - 1 - i) with x by lia. + Msimpl. + rewrite Z_specb. + distribute_scale. + reflexivity. +Qed. + Lemma f_to_vec_cnot : forall (n i j : nat) (f : nat -> bool), i < n -> j < n -> i <> j -> (pad_ctrl n i j σx) × (f_to_vec n f) = f_to_vec n (update f j (f j ⊕ f i)). @@ -543,11 +668,35 @@ Proof. repeat rewrite shift_simplify. replace (d + (i + 1)) with (i + 1 + d) by lia. rewrite update_index_eq. - distribute_plus. restore_dims. - repeat rewrite <- kron_assoc by auto with wf_db. - destruct (f i); destruct (f (i + 1 + d)); simpl; Msimpl. - all: autorewrite with ket_db; reflexivity. + rewrite <- !kron_assoc by auto_wf. + restore_dims. + rewrite kron_mixed_product' by lia. + rewrite Mmult_1_l by auto_wf. + restore_dims. + rewrite (kron_assoc (f_to_vec i f)) by auto_wf. + restore_dims. + rewrite !(kron_assoc (f_to_vec i f)) by auto_wf. + restore_dims. + f_equal. + rewrite kron_mixed_product, Mmult_1_l by auto_wf. + f_equal. + simpl. + restore_dims. + distribute_plus. + rewrite !kron_mixed_product. + rewrite 2!Mmult_1_l by auto_wf. + symmetry. + rewrite <- (Mmult_1_l _ _ (∣ f i ⟩)) at 1 by auto_wf. + rewrite <- Mplus10. + distribute_plus. + rewrite !(Mmult_assoc _ _ (∣ f i ⟩)). + rewrite bra1_equiv, bra1ket_eqb. + rewrite bra0_equiv, bra0ket_eqb. + destruct (f i); simpl; rewrite Mmult_0_r, !kron_0_l. + + rewrite xorb_true_r. + now rewrite X_specb. + + now rewrite xorb_false_r. - repeat rewrite (f_to_vec_split 0 (j + (1 + d + 1) + x0) j); try lia. rewrite f_to_vec_update_oob by lia. rewrite update_index_eq. @@ -558,11 +707,29 @@ Proof. replace (d + (j + 1)) with (j + 1 + d) by lia. rewrite update_index_neq by lia. replace (j + (1 + d + 1) + x0 - 1 - j - 1 - d) with x0 by lia. - distribute_plus. restore_dims. - repeat rewrite <- kron_assoc by auto with wf_db. - destruct (f j); destruct (f (j + 1 + d)); simpl; Msimpl. - all: autorewrite with ket_db; reflexivity. + rewrite kron_assoc, !(kron_assoc (f_to_vec j f)) by auto_wf. + restore_dims. + rewrite kron_mixed_product' by lia. + f_equal; [lia | apply Mmult_1_l; auto_wf|]. + rewrite <- 4!kron_assoc by auto_wf. + restore_dims. + rewrite kron_mixed_product. + f_equal; [| apply Mmult_1_l; auto_wf]. + distribute_plus. + rewrite !kron_mixed_product, Mmult_1_l by auto_wf. + rewrite !Mmult_assoc. + rewrite Mmult_1_l by auto_wf. + rewrite bra1_equiv, bra1ket_eqb. + rewrite bra0_equiv, bra0ket_eqb. + destruct (f (j + 1 + d)); simpl; rewrite Mmult_0_r, !kron_0_r. + + rewrite Mplus_0_r. + rewrite xorb_true_r. + rewrite Mmult_1_r, ket1_equiv by auto_wf. + now rewrite X_specb. + + rewrite Mplus_0_l. + rewrite Mmult_1_r, ket0_equiv by auto_wf. + now rewrite xorb_false_r. Qed. Lemma f_to_vec_swap : forall (n i j : nat) (f : nat -> bool), @@ -581,7 +748,7 @@ Proof. rewrite update_twice_neq by auto. rewrite update_twice_eq. reflexivity. - all: destruct (f i); destruct (f j); auto. + all: destruct (f i); destruct (f j); reflexivity. Qed. Lemma f_to_vec_phase_shift : forall (n i : nat) (θ : R) (f : nat -> bool), @@ -634,6 +801,223 @@ Local Close Scope R_scope. #[global] Hint Rewrite f_to_vec_cnot f_to_vec_σx f_to_vec_phase_shift using lia : f_to_vec_db. #[global] Hint Rewrite (@update_index_eq bool) (@update_index_neq bool) (@update_twice_eq bool) (@update_same bool) using lia : f_to_vec_db. +Import Modulus. + +Lemma kron_f_to_vec {n m p q} (A : Matrix (2^n) (2^m)) + (B : Matrix (2^p) (2^q)) f : + @mat_equiv _ 1 (A ⊗ B × f_to_vec (m + q) f) + ((A × f_to_vec m f (* : Matrix _ 1 *)) ⊗ + (B × f_to_vec q (fun k => f (m + k)) (* : Matrix _ 1) *))). +Proof. + rewrite <- kron_mixed_product. + rewrite f_to_vec_merge. + Morphisms.f_equiv. + apply f_to_vec_eq. + intros; bdestructΩ'; f_equal; lia. +Qed. + +Lemma kron_f_to_vec_eq {n m p q : nat} (A : Matrix (2^n) (2^m)) + (B : Matrix (2^p) (2^q)) (f : nat -> bool) : WF_Matrix A -> WF_Matrix B -> + A ⊗ B × f_to_vec (m + q) f + = A × f_to_vec m f ⊗ (B × f_to_vec q (fun k : nat => f (m + k))). +Proof. + intros. + prep_matrix_equivalence. + apply kron_f_to_vec. +Qed. + +Lemma f_to_vec_split' n m f : + mat_equiv (f_to_vec (n + m) f) + (f_to_vec n f ⊗ f_to_vec m (fun k => f (n + k))). +Proof. + intros i j Hi Hj. + rewrite f_to_vec_merge. + erewrite f_to_vec_eq; [reflexivity|]. + intros; simpl; bdestructΩ'; f_equal; lia. +Qed. + +Lemma f_to_vec_split'_eq n m f : + (f_to_vec (n + m) f) = + (f_to_vec n f ⊗ f_to_vec m (fun k => f (n + k))). +Proof. + apply mat_equiv_eq; [..|apply f_to_vec_split']; auto with wf_db. +Qed. + +Lemma f_to_vec_1_eq f : + f_to_vec 1 f = if f 0 then ∣1⟩ else ∣0⟩. +Proof. + cbn. + unfold ket. + rewrite kron_1_l by (destruct (f 0); auto with wf_db). + now destruct (f 0). +Qed. + +Lemma f_to_vec_1_mult_r f (A : Matrix (2^1) (2^1)) : + A × f_to_vec 1 f = (fun x j => if j =? 0 then A x (Nat.b2n (f 0)) else 0%R). +Proof. + cbn. + rewrite kron_1_l by auto with wf_db. + apply functional_extensionality; intros i. + apply functional_extensionality; intros j. + unfold Mmult. + simpl. + destruct (f 0); + unfold ket; + simpl; + now destruct j; simpl; Csimpl. +Qed. + +Lemma f_to_vec_1_mult_r_decomp f (A : Matrix (2^1) (2^1)) : + A × f_to_vec 1 f ≡ + A 0 (Nat.b2n (f 0)) .* ∣0⟩ .+ + A 1 (Nat.b2n (f 0)) .* ∣1⟩. +Proof. + rewrite f_to_vec_1_mult_r. + intros i j Hi Hj. + replace j with 0 by lia. + simpl. + autounfold with U_db. + do 2 (try destruct i); [..| simpl in *; lia]; + now Csimpl. +Qed. + +Lemma f_to_vec_1_mult_r_decomp_eq f (A : Matrix (2^1) (2^1)) : + WF_Matrix A -> + A × f_to_vec 1 f = + A 0 (Nat.b2n (f 0)) .* ∣0⟩ .+ + A 1 (Nat.b2n (f 0)) .* ∣1⟩. +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + apply f_to_vec_1_mult_r_decomp. +Qed. + +Lemma qubit0_f_to_vec : ∣0⟩ = f_to_vec 1 (fun x => false). +Proof. now rewrite f_to_vec_1_eq. Qed. + +Lemma qubit1_f_to_vec : ∣1⟩ = f_to_vec 1 (fun x => x =? 0). +Proof. now rewrite f_to_vec_1_eq. Qed. + +Lemma ket_f_to_vec b : ∣ Nat.b2n b ⟩ = f_to_vec 1 (fun x => b). +Proof. + destruct b; [apply qubit1_f_to_vec | apply qubit0_f_to_vec]. +Qed. + +Lemma f_to_vec_1_mult_r_decomp_eq' f (A : Matrix (2^1) (2^1)) : + WF_Matrix A -> + A × f_to_vec 1 f = + A 0 (Nat.b2n (f 0)) .* f_to_vec 1 (fun x => false) .+ + A 1 (Nat.b2n (f 0)) .* f_to_vec 1 (fun x => x=?0). +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + rewrite f_to_vec_1_mult_r_decomp. + rewrite 2!f_to_vec_1_eq. + easy. +Qed. + +Lemma f_to_vec_1_mult_l_decomp f (A : Matrix (2^1) (2^1)) : + (f_to_vec 1 f) ⊤ × A ≡ + A (Nat.b2n (f 0)) 0 .* (∣0⟩ ⊤) .+ + A (Nat.b2n (f 0)) 1 .* (∣1⟩ ⊤). +Proof. + rewrite <- (transpose_involutive _ _ A). + rewrite <- Mmult_transpose, <- Mscale_trans. + intros i j Hi Hj. + apply (f_to_vec_1_mult_r_decomp f (A ⊤)); easy. +Qed. + +Lemma f_to_vec_1_mult_l_decomp_eq f (A : Matrix (2^1) (2^1)) : + WF_Matrix A -> + (f_to_vec 1 f) ⊤ × A = + A (Nat.b2n (f 0)) 0 .* (∣0⟩ ⊤) .+ + A (Nat.b2n (f 0)) 1 .* (∣1⟩ ⊤). +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + apply f_to_vec_1_mult_l_decomp. +Qed. + +Lemma f_to_vec_1_mult_l_decomp_eq' f (A : Matrix (2^1) (2^1)) : + WF_Matrix A -> + (f_to_vec 1 f) ⊤ × A = + A (Nat.b2n (f 0)) 0 .* ((f_to_vec 1 (fun x => false)) ⊤) .+ + A (Nat.b2n (f 0)) 1 .* ((f_to_vec 1 (fun x => x =? 0)) ⊤). +Proof. + intros. + apply mat_equiv_eq; auto with wf_db. + rewrite f_to_vec_1_mult_l_decomp_eq by easy. + now rewrite qubit0_f_to_vec, qubit1_f_to_vec. +Qed. + +Lemma basis_trans_basis {n} i j : + ((basis_vector n i) ⊤ × basis_vector n j) 0 0 = + if (i =? j) && (i eqb (f k) (g k)) (seq 0 n)) .* I 1. +Proof. + prep_matrix_equivalence. + intros [] []; [|lia..]; intros _ _. + rewrite 2!basis_f_to_vec. + rewrite basis_trans_basis. + pose proof (funbool_to_nat_bound n f). + simplify_bools_lia_one_kernel. + unfold scale. + cbn. + rewrite Cmult_1_r. + unfold b2R. + rewrite (if_dist _ _ _ RtoC). + apply f_equal_if; [|easy..]. + apply eq_iff_eq_true. + rewrite Nat.eqb_eq, forallb_seq0, <- funbool_to_nat_eq_iff. + now setoid_rewrite eqb_true_iff. +Qed. + +Lemma f_to_vec_transpose_f_to_vec' n f g : + transpose (f_to_vec n f) × f_to_vec n g = + (if funbool_to_nat n f =? funbool_to_nat n g then + C1 else C0) .* I 1. +Proof. + rewrite f_to_vec_transpose_f_to_vec. + f_equal. + unfold b2R. + rewrite (if_dist R C). + apply f_equal_if; [|easy..]. + apply eq_iff_eq_true. + rewrite forallb_seq0, Nat.eqb_eq. + setoid_rewrite eqb_true_iff. + apply funbool_to_nat_eq_iff. +Qed. + +Lemma f_to_vec_transpose_self n f : + transpose (f_to_vec n f) × f_to_vec n f = + I 1. +Proof. + rewrite f_to_vec_transpose_f_to_vec', Nat.eqb_refl. + now Msimpl_light. +Qed. (*******************************) (** Indexed Vector Sum **) @@ -664,7 +1048,6 @@ Proof. bdestruct_all; simpl; lca. Qed. -Local Opaque Nat.mul. Lemma vsum_sum : forall d n (f : nat -> Vector d), big_sum f (2 * n) = big_sum (fun i => f (2 * i)%nat) n .+ big_sum (fun i => f (2 * i + 1)%nat) n. @@ -678,7 +1061,6 @@ Proof. replace (2 * n + 1)%nat with (S (2 * n)) by lia. lma. Qed. -Local Transparent Nat.mul. Lemma vsum_split : forall {d} (n i : nat) (v : nat -> Vector d), (i < n)%nat -> @@ -703,52 +1085,6 @@ Proof. lma. Qed. -Lemma vsum_eq_up_to_fswap : forall {d} n f (v : nat -> Vector d) x y, - (x < n)%nat -> (y < n)%nat -> - big_sum (fun i => v (f i)) n = big_sum (fun i => v (fswap f x y i)) n. -Proof. - intros d n f v x y Hx Hy. - bdestruct (x =? y). - subst. - apply big_sum_eq. - apply functional_extensionality; intros. - unfold fswap. - bdestruct_all; subst; reflexivity. - bdestruct (x Vector 2), (f O) ⊗ vkron n (shift f 1) = vkron (S n) f. Proof. intros n f WF. - induction n. - simpl. Msimpl. reflexivity. + induction n; [simpl; now rewrite kron_1_l, kron_1_r|]. remember (S n) as n'. simpl. rewrite <- IHn; clear IHn. @@ -801,8 +1136,7 @@ Lemma kron_n_f_to_vec : forall n (A : Square 2) f, n ⨂ A × f_to_vec n f = vkron n (fun k => A × ∣ f k ⟩ ). Proof. intros n A f. - induction n; simpl. - Msimpl. reflexivity. + induction n; simpl; [now Msimpl_light|]. restore_dims. rewrite kron_mixed_product. rewrite IHn. @@ -813,9 +1147,7 @@ Lemma Mscale_vkron_distr_r : forall n x (f : nat -> Vector 2), vkron n (fun i => x .* f i) = x ^ n .* vkron n f. Proof. intros n x f. - induction n. - simpl. Msimpl. reflexivity. - simpl. + induction n; simpl; [now Msimpl_light|]. rewrite IHn. distribute_scale. rewrite Cmult_comm. @@ -828,20 +1160,20 @@ Lemma vkron_split : forall n i (f : nat -> Vector 2), vkron n f = (vkron i f) ⊗ f i ⊗ (vkron (n - 1 - i) (shift f (i + 1))). Proof. intros. - induction n; try lia. + induction n; [lia|]. bdestruct (i =? n). - subst. - replace (S n - 1 - n)%nat with O by lia. - simpl. Msimpl. - reflexivity. - assert (i < n)%nat by lia. - specialize (IHn H2). - replace (S n - 1 - i)%nat with (S (n - 1 - i))%nat by lia. - simpl. - rewrite IHn. - unfold shift. - replace (n - 1 - i + (i + 1))%nat with n by lia. - restore_dims; repeat rewrite kron_assoc; auto 100 with wf_db. + - subst. + replace (S n - 1 - n)%nat with O by lia. + simpl. + now rewrite kron_1_r. + - assert (i < n)%nat by lia. + (* specialize (IHn H2). *) + replace (S n - 1 - i)%nat with (S (n - 1 - i))%nat by lia. + simpl. + rewrite IHn by lia. + unfold shift. + replace (n - 1 - i + (i + 1))%nat with n by lia. + restore_dims; repeat rewrite kron_assoc; auto 100 with wf_db. Qed. Lemma vkron_eq : forall n (f f' : nat -> Vector 2), @@ -865,19 +1197,17 @@ Lemma basis_vector_prepend_0 : forall n k, ∣0⟩ ⊗ basis_vector n k = basis_vector (2 * n) k. Proof. intros. - unfold basis_vector; solve_matrix. (* solve_matrix doesn't work? *) - repeat rewrite andb_true_r. - bdestruct (x / n =? 0). - rewrite H1. apply Nat.div_small_iff in H1; auto. - rewrite Nat.mod_small by auto. - destruct (x =? k); lca. - assert (H1' := H1). - rewrite Nat.div_small_iff in H1'; auto. - destruct (x / n)%nat; try lia. - bdestructΩ (x =? k). - destruct n0; lca. - destruct (x / n)%nat; try lca. - destruct n0; lca. + prep_matrix_equivalence. + unfold basis_vector, kron. + intros i j Hi Hj. + rewrite ket0_eqb. + rewrite Cmult_if_if_1_l. + replace j with 0 by lia. + simpl_bools. + symmetry. + rewrite (eqb_iff_div_mod_eqb n). + rewrite (Nat.mod_small k n), (Nat.div_small k n) by easy. + bdestructΩ'. Qed. Lemma basis_vector_prepend_1 : forall n k, @@ -885,52 +1215,38 @@ Lemma basis_vector_prepend_1 : forall n k, ∣1⟩ ⊗ basis_vector n k = basis_vector (2 * n) (k + n). Proof. intros. - unfold basis_vector; solve_matrix. - all: repeat rewrite andb_true_r. - specialize (Nat.div_mod x n H) as DM. - destruct (x / n)%nat. - rewrite Nat.mul_0_r, Nat.add_0_l in DM. - assert (x < n)%nat. - rewrite DM. apply Nat.mod_upper_bound; auto. - bdestructΩ (x =? k + n)%nat. - lca. - destruct n0. - bdestruct (x mod n =? k). - bdestructΩ (x =? k + n); lca. - bdestructΩ (x =? k + n); lca. - assert (x >= 2 * n)%nat. - assert (n * S (S n0) >= 2 * n)%nat. - clear. induction n0; lia. - lia. - bdestructΩ (x =? k + n); lca. - destruct (x / n)%nat; try lca. - destruct n0; lca. -Qed. - -Local Opaque Nat.mul Nat.div Nat.modulo. + intros. + prep_matrix_equivalence. + unfold basis_vector, kron. + intros i j Hi Hj. + rewrite ket1_eqb. + rewrite Cmult_if_if_1_l. + replace j with 0 by lia. + simpl_bools. + symmetry. + rewrite (eqb_iff_div_mod_eqb n). + replace ((k + n) / n) with 1 + by (symmetry; rewrite Kronecker.div_eq_iff; lia). + rewrite (mod_n_to_2n (k + n)) by lia. + bdestructΩ'. +Qed. + Lemma basis_vector_append_0 : forall n k, n <> 0 -> k < n -> basis_vector n k ⊗ ∣0⟩ = basis_vector (2 * n) (2 * k). Proof. intros. - unfold basis_vector; solve_matrix. - rewrite Nat.div_1_r. - bdestruct (y =? 0); subst. - 2: repeat rewrite andb_false_r; lca. - bdestruct (x =? 2 * k); subst. - rewrite Nat.mul_comm. - rewrite Nat.div_mul by auto. - rewrite Nat.eqb_refl. - rewrite Nat.mod_mul, Nat.mod_0_l by auto. - lca. - bdestruct (x / 2 =? k); simpl; try lca. - destruct (x mod 2) eqn:m. - contradict H1. - rewrite <- H2. - apply Nat.div_exact; auto. - destruct n0; try lca. - rewrite Nat.mod_0_l by auto. - lca. + apply mat_equiv_eq; [auto using WF_Matrix_dim_change with wf_db zarith..|]. + unfold basis_vector, kron. + intros i j Hi Hj. + rewrite ket0_eqb. + rewrite Cmult_if_if_1_l. + replace j with 0 by lia. + simpl_bools. + symmetry. + rewrite (eqb_iff_div_mod_eqb 2). + rewrite Nat.mul_comm, Nat.Div0.mod_mul, Nat.div_mul by easy. + now rewrite andb_comm. Qed. Lemma basis_vector_append_1 : forall n k, @@ -938,49 +1254,30 @@ Lemma basis_vector_append_1 : forall n k, basis_vector n k ⊗ ∣1⟩ = basis_vector (2 * n) (2 * k + 1). Proof. intros. - unfold basis_vector; solve_matrix. - rewrite Nat.div_1_r. - bdestruct (y =? 0); subst. - 2: repeat rewrite andb_false_r; lca. - bdestruct (x =? 2 * k + 1); subst. - rewrite Nat.mul_comm. - rewrite Nat.div_add_l by auto. - replace (1 / 2) with 0 by auto. + apply mat_equiv_eq; [auto using WF_Matrix_dim_change with wf_db zarith..|]. + unfold basis_vector, kron. + intros i j Hi Hj. + rewrite ket1_eqb. + rewrite Cmult_if_if_1_l. + replace j with 0 by lia. + simpl_bools. + symmetry. + rewrite (eqb_iff_div_mod_eqb 2). + rewrite Nat.mul_comm, mod_add_l, Nat.div_add_l by easy. rewrite Nat.add_0_r. - rewrite Nat.eqb_refl. - rewrite Nat.add_comm, Nat.mod_add by auto. - replace (1 mod 2) with 1 by auto. - replace (0 mod 1) with 0 by auto. - lca. - bdestruct (x / 2 =? k); simpl; try lca. - destruct (x mod 2) eqn:m. - replace (0 mod 1) with 0 by auto; lca. - destruct n0; try lca. - contradict H1. - rewrite <- H2. - remember 2 as two. - rewrite <- m. - subst. - apply Nat.div_mod; auto. -Qed. -Local Transparent Nat.mul Nat.div Nat.modulo. + now rewrite andb_comm. +Qed. Lemma kron_n_0_is_0_vector : forall (n:nat), n ⨂ ∣0⟩ = basis_vector (2 ^ n) O. Proof. intros. - induction n. + induction n; [solve_matrix_fast|]. simpl. - prep_matrix_equality. - unfold basis_vector, I. - bdestruct_all; reflexivity. - simpl. - rewrite IHn. replace (1 ^ n)%nat with 1%nat. - rewrite (basis_vector_append_0 (2 ^ n) 0). + rewrite IHn. + replace (1 ^ n)%nat with 1%nat by now rewrite Nat.pow_1_l. + rewrite (basis_vector_append_0 (2 ^ n) 0) by show_nonzero. rewrite Nat.mul_0_r. reflexivity. - apply Nat.pow_nonzero. lia. - apply pow_positive. lia. - rewrite Nat.pow_1_l. reflexivity. Qed. Lemma vkron_to_vsum1 : forall n (c : R), @@ -991,45 +1288,43 @@ Proof. intros n c Hn. destruct n; try lia. induction n. - simpl. - repeat rewrite <- big_sum_extend_r. - Msimpl. - rewrite Rmult_0_r, Cexp_0, Mscale_1_l. - replace (basis_vector 2 0) with ∣0⟩ by solve_matrix. - replace (basis_vector 2 1) with ∣1⟩ by solve_matrix. - reflexivity. - remember (S n) as n'. - rewrite <- vkron_extend_l; auto with wf_db. - replace (shift (fun k : nat => ∣0⟩ .+ Cexp (c * 2 ^ (S n' - k - 1)) .* ∣1⟩) 1) with (fun k : nat => ∣0⟩ .+ Cexp (c * 2 ^ (n' - k - 1)) .* ∣1⟩). - 2: { unfold shift. - apply functional_extensionality; intro k. - replace (S n' - (k + 1) - 1)%nat with (n' - k - 1)%nat by lia. - reflexivity. } - rewrite IHn by lia. - replace (S n' - 0 - 1)%nat with n' by lia. - remember (2 ^ n')%nat as N. - assert (HN: (N > 0)%nat). - subst. apply pow_positive. lia. - replace (2 ^ n')%R with (INR N). - 2: { subst. rewrite pow_INR. simpl INR. replace (1+1)%R with 2%R by lra. - reflexivity. } - replace (2 ^ S n')%nat with (2 * N)%nat. - 2: { subst. unify_pows_two. } - clear - HN. - rewrite kron_plus_distr_r. - rewrite 2 kron_Msum_distr_l. - replace (2 * N) with (N + N) by lia. - rewrite big_sum_sum. - replace (N + N) with (2 * N) by lia. - apply f_equal_gen; try apply f_equal; apply big_sum_eq_bounded; intros. - distribute_scale. - rewrite basis_vector_prepend_0 by lia. - reflexivity. - distribute_scale. - rewrite <- Cexp_add, <- Rmult_plus_distr_l, <- plus_INR. - rewrite basis_vector_prepend_1 by lia. - rewrite Nat.add_comm. - reflexivity. + - simpl. + repeat rewrite <- big_sum_extend_r. + Msimpl. + rewrite Rmult_0_r, Cexp_0, Mscale_1_l. + replace (basis_vector 2 0) with ∣0⟩ by solve_matrix_fast. + replace (basis_vector 2 1) with ∣1⟩ by solve_matrix_fast. + reflexivity. + - remember (S n) as n'. + rewrite <- vkron_extend_l; auto with wf_db. + replace (shift (fun k => ∣0⟩ .+ Cexp (c * 2 ^ (S n' - k - 1)) .* ∣1⟩) 1) + with (fun k => ∣0⟩ .+ Cexp (c * 2 ^ (n' - k - 1)) .* ∣1⟩). + 2: { unfold shift. + apply functional_extensionality; intro k. + replace (S n' - (k + 1) - 1)%nat with (n' - k - 1)%nat by lia. + reflexivity. } + rewrite IHn by lia. + replace (S n' - 0 - 1)%nat with n' by lia. + remember (2 ^ n')%nat as N. + assert (HN: (N > 0)%nat) by (subst; show_nonzero). + replace (2 ^ n')%R with (INR N). + 2: { subst. rewrite pow_INR. f_equal; lra. } + replace (2 ^ S n')%nat with (2 * N)%nat by (subst; unify_pows_two). + clear - HN. + rewrite kron_plus_distr_r. + rewrite 2 kron_Msum_distr_l. + replace (2 * N) with (N + N) by lia. + rewrite big_sum_sum. + replace (N + N) with (2 * N) by lia. + apply f_equal_gen; try apply f_equal; apply big_sum_eq_bounded; intros. + + distribute_scale. + rewrite basis_vector_prepend_0 by lia. + reflexivity. + + distribute_scale. + rewrite <- Cexp_add, <- Rmult_plus_distr_l, <- plus_INR. + rewrite basis_vector_prepend_1 by lia. + rewrite Nat.add_comm. + reflexivity. Qed. Local Open Scope R_scope. @@ -1051,8 +1346,8 @@ Proof. Msimpl. unfold nat_to_funbool; simpl. rewrite 2 update_index_eq. - replace (basis_vector 2 0) with ∣0⟩ by solve_matrix. - replace (basis_vector 2 1) with ∣1⟩ by solve_matrix. + replace (basis_vector 2 0) with ∣0⟩ by solve_matrix_fast. + replace (basis_vector 2 1) with ∣1⟩ by solve_matrix_fast. destruct (f O); simpl; restore_dims; lma. - remember (S n) as n'. simpl vkron. @@ -1107,10 +1402,7 @@ Local Transparent Nat.mul. Lemma H_spec : (* slightly different from hadamard_on_basis_state *) forall b : bool, hadamard × ∣ b ⟩ = / √ 2 .* (∣ 0 ⟩ .+ (-1)^b .* ∣ 1 ⟩). Proof. - intro b. - destruct b; simpl; autorewrite with ket_db. - replace (/ √ 2 * (-1 * 1))%C with (- / √ 2)%C by lca. - reflexivity. reflexivity. + apply hadamard_on_ket. Qed. Lemma H_kron_n_spec : forall n x, (n > 0)%nat -> @@ -1151,3 +1443,283 @@ Proof. induction n; try reflexivity. simpl. rewrite IHn. reflexivity. Qed. + + +(* Generalizing vkron to larger matrices *) +Fixpoint big_kron' (ns ms : nat -> nat) + (As : forall i, Matrix (2 ^ ns i) (2 ^ ms i)) (n : nat) : + Matrix (2 ^ big_sum ns n) (2 ^ (big_sum ms n)) := + match n with + | O => I (2 ^ 0) + | S n' => + big_kron' ns ms As n' ⊗ As n' + end. + +Lemma WF_big_kron' ns ms As n + (HAs : forall k, (k < n)%nat -> WF_Matrix (As k)) : + WF_Matrix (big_kron' ns ms As n). +Proof. induction n; cbn; auto_wf. Qed. + +#[export] Hint Resolve WF_big_kron' : wf_db. + +Lemma big_kron'_eq_bounded ns ms As Bs n + (HAB : forall k, (k < n)%nat -> As k = Bs k) : + big_kron' ns ms As n = big_kron' ns ms Bs n. +Proof. + induction n; [easy|]. + cbn; f_equal; auto. +Qed. + +Lemma big_kron'_Mmult ns ms os As Bs n + (HAs : forall k, (k < n)%nat -> WF_Matrix (As k)) + (HBs : forall k, (k < n)%nat -> WF_Matrix (Bs k)) : + big_kron' ns ms As n × big_kron' ms os Bs n = + big_kron' ns os (fun i => As i × Bs i) n. +Proof. + induction n; [apply Mmult_1_l; auto_wf|]. + cbn. + restore_dims. + rewrite kron_mixed_product. + f_equal. + apply IHn; auto. +Qed. + +Lemma big_kron'_transpose ns ms As n : + (big_kron' ns ms As n) ⊤%M = + big_kron' ms ns (fun k => (As k) ⊤%M) n. +Proof. + induction n; cbn. + - apply id_transpose_eq. + - change ((?A ⊗ ?B) ⊤%M) with + (transpose A ⊗ transpose B). + f_equal. + auto. +Qed. + +Lemma big_kron'_transpose' ns ms As n n' m' : + @transpose n' m' (big_kron' ns ms As n) = + big_kron' ms ns (fun k => (As k) ⊤%M) n. +Proof. apply big_kron'_transpose. Qed. + +Lemma big_kron'_adjoint ns ms As n : + (big_kron' ns ms As n) † = + big_kron' ms ns (fun k => (As k) †) n. +Proof. + induction n; cbn. + - apply id_adjoint_eq. + - restore_dims. + rewrite kron_adjoint. + f_equal. + auto. +Qed. + +Lemma big_kron'_id ns As n + (HAs : forall k, (k < n)%nat -> (As k) = I (2 ^ (ns k))) : + big_kron' ns ns As n = I (2 ^ (big_sum ns n)). +Proof. + induction n; [easy|]. + simpl. + rewrite IHn by auto. + rewrite HAs by auto. + rewrite id_kron. + now unify_pows_two. +Qed. + +Lemma big_kron'_unitary ns As n + (HAs : forall k, (k < n)%nat -> WF_Unitary (As k)) : + WF_Unitary (big_kron' ns ns As n). +Proof. + pose proof (fun k Hk => proj1 (HAs k Hk)) as HAs'. + split; [auto_wf|]. + rewrite big_kron'_adjoint. + rewrite big_kron'_Mmult by (intros; auto_wf). + apply big_kron'_id. + intros k Hk. + now apply HAs. +Qed. + +Lemma f_to_vec_big_split n ns f : + f_to_vec (big_sum ns n) f = + big_kron' ns (fun _ => O) + (fun i => f_to_vec (ns i) (fun k => f (big_sum ns i + k)%nat)) n. +Proof. + induction n; [easy|]. + cbn. + rewrite <- IHn. + rewrite f_to_vec_split'_eq. + f_equal. + now rewrite big_sum_0. +Qed. + +Lemma big_kron'_f_to_vec ns ms As n f + (HAs : forall k, (k < n)%nat -> WF_Matrix (As k)): + big_kron' ns ms As n × f_to_vec (big_sum ms n) f = + big_kron' ns (fun _ => O) + (fun i => As i × f_to_vec (ms i) (fun k => f (big_sum ms i + k)%nat)) n. +Proof. + rewrite f_to_vec_big_split. + restore_dims. + rewrite big_kron'_Mmult by (intros; auto_wf). + easy. +Qed. + +Lemma big_kron'_split_add ns ms As n n' + (HAs : forall k, (k < n + n')%nat -> WF_Matrix (As k)) : + big_kron' ns ms As (n + n') = + big_kron' ns ms As n ⊗ + big_kron' (fun k => ns (n + k)%nat) (fun k => ms (n + k)%nat) + (fun k => As (n + k)%nat) n'. +Proof. + induction n'. + - simpl. + now rewrite Nat.add_0_r, kron_1_r. + - rewrite Nat.add_succ_r. simpl. + rewrite IHn' by auto with zarith. + restore_dims. + apply kron_assoc; auto_wf. +Qed. + +Lemma big_kron'_split_add' ns ms As n n' + (HAs : forall k, (k < n + n')%nat -> WF_Matrix (As k)) : + big_kron' ns ms As (n + n') = + big_kron' ns ms As n ⊗ + big_kron' (fun k => ns (k + n)%nat) (fun k => ms (k + n)%nat) + (fun k => As (k + n)%nat) n'. +Proof. + rewrite big_kron'_split_add by auto. + f_equal; + [f_equal; apply big_sum_eq_bounded; intros ? ?; f_equal; lia..|]. + induction n'; [easy|]. + simpl. + rewrite IHn' by auto with zarith. + rewrite (Nat.add_comm n n'). + f_equal; + f_equal; apply big_sum_eq_bounded; + intros ? ?; f_equal; lia. +Qed. + +Lemma big_kron'_split n i (Hk : (i < n)%nat) ns ms As + (HAs : forall k, (k < n)%nat -> WF_Matrix (As k)) : + big_kron' ns ms As n = + big_kron' ns ms As i ⊗ As i ⊗ + big_kron' (fun k => ns (k + 1 + i)%nat) (fun k => ms (k + 1 + i)%nat) + (fun k => As (k + 1 + i)%nat) (n - 1 - i). +Proof. + fill_differences. + replace (i + 1 + x - 1 - i)%nat with x by lia. + rewrite 2!big_kron'_split_add' by auto with zarith. + f_equal. + 1, 2: rewrite Nat.add_comm, <- Nat.pow_add_r; simpl; f_equal; lia. + 1, 2: f_equal; apply big_sum_eq_bounded; intros ? ?; f_equal; lia. + - f_equal. + cbn. + apply kron_1_l, HAs; lia. + - induction x; [easy|]. + simpl. + rewrite IHx, (Nat.add_comm i 1), Nat.add_assoc by auto with zarith. + f_equal. + 1, 2: f_equal; apply big_sum_eq_bounded; intros ? ?; f_equal; lia. +Qed. + +Lemma big_kron'_0_0_eq_up_to_fswap + As n x y (Hx : (x < n)%nat) (Hy : (y < n)%nat) + (HAs : forall k, (k < n)%nat -> WF_Matrix (As k)) + f (Hf : perm_bounded n f) : + big_kron' (fun _ => O) (fun _ => O) (As ∘ f)%prg n = + big_kron' (fun _ => O) (fun _ => O) (As ∘ fswap f x y)%prg n. +Proof. + bdestruct (x =? y); + [apply big_kron'_eq_bounded; unfold fswap, compose; intros; + bdestructΩ'|]. + assert (Hfs : perm_bounded n (fswap f x y)) + by (intros k Hk; unfold fswap; bdestructΩ'; auto). + bdestruct (x WF_Matrix (As k)) : + big_kron' (fun _ => O) (fun _ => O) As n = + big_kron' (fun _ => O) (fun _ => O) (As ∘ f)%prg n. +Proof. + intros. + generalize dependent f. + induction n; + [reflexivity|]. + intros f Hf. + pose proof Hf as [g Hg]. + destruct (Hg n) as [_ [H1' [_ H2']]]; try lia. + symmetry. + rewrite (big_kron'_0_0_eq_up_to_fswap _ _ (g n) n) + by auto with perm_bounded_db. + simpl. + unfold compose. + rewrite fswap_simpl2. + unfold compose. + rewrite H2'. + specialize (IHn ltac:(auto) (fswap f (g n) n)). + rewrite IHn by + (apply fswap_at_boundary_permutation; auto). + reflexivity. +Qed. \ No newline at end of file diff --git a/coq-quantumlib.opam b/coq-quantumlib.opam index 98f6d97..e6ee56d 100644 --- a/coq-quantumlib.opam +++ b/coq-quantumlib.opam @@ -1,6 +1,6 @@ # This file is generated by dune, edit dune-project instead opam-version: "2.0" -version: "1.5.1" +version: "1.6.0" synopsis: "Coq library for reasoning about quantum programs" description: """ inQWIRE's QuantumLib is a Coq library for reasoning @@ -14,7 +14,7 @@ doc: "https://inqwire.github.io/QuantumLib/toc.html" bug-reports: "https://github.com/inQWIRE/QuantumLib/issues" depends: [ "dune" {>= "2.8"} - "coq" {>= "8.16" < "8.20"} + "coq" {>= "8.16" & < "8.20"} "odoc" {with-doc} ] build: [ @@ -31,4 +31,4 @@ build: [ "@doc" {with-doc} ] ] -dev-repo: "git+https://github.com/inQWIRE/QuantumLib.git" \ No newline at end of file +dev-repo: "git+https://github.com/inQWIRE/QuantumLib.git" diff --git a/dune-project b/dune-project index 7a9a3a6..0812e4d 100644 --- a/dune-project +++ b/dune-project @@ -1,6 +1,6 @@ (lang dune 2.8) (name coq-quantumlib) -(version 1.5.1) +(version 1.6.0) (using coq 0.2) (generate_opam_files true) @@ -18,4 +18,4 @@ "\| about quantum computation and quantum programs. ) (depends - (coq (>= 8.12)))) + (coq (and (>= 8.16) (< 8.20)))))