Skip to content

Commit

Permalink
Add direct Rust VersorGate tests
Browse files Browse the repository at this point in the history
... which were necessary because I'd borked the matrix calculations and
forgotten to write any tests of them.
  • Loading branch information
jakelishman committed Jan 10, 2025
1 parent da462a4 commit df1a3c2
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 8 deletions.
94 changes: 87 additions & 7 deletions crates/accelerate/src/qi/versor_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,16 +306,17 @@ impl VersorGate {
//
// and the matrix is in SU(2) with determinant 1. We can find the phase angle `x` via the
// determinant of the given, which we've been promised has magnitude 1, so we can phase it
// out of the matrix terms by multiplying by the conjugate. The versor representation is
// then just `(x, [a, b, c, d])`.
// out of the matrix terms by multiplying by the root of the conjugate (since `det(aA) =
// a^dim(A) det(A)`). The versor representation is then just `(x/2, [a, b, c, d])`.
let det = matrix.get(0, 0) * matrix.get(1, 1) - matrix.get(0, 1) * matrix.get(1, 0);
let inv_rot = det.conj().sqrt();
Self {
phase: det.arg(),
phase: 0.5 * det.arg(),
action: UnitQuaternion::new_unchecked(Quaternion::new(
(det.conj() * matrix.get(0, 0)).re,
(det.conj() * matrix.get(0, 0)).im,
(det.conj() * matrix.get(0, 1)).re,
(det.conj() * matrix.get(0, 1)).im,
(inv_rot * matrix.get(0, 0)).re,
(inv_rot * matrix.get(0, 0)).im,
(inv_rot * matrix.get(0, 1)).re,
(inv_rot * matrix.get(0, 1)).im,
)),
}
}
Expand Down Expand Up @@ -443,3 +444,82 @@ fn unitary_frobenius_distance_square<M: Matrix1q>(matrix: &M) -> f64 {
matrix.get(0, 0) * matrix.get(1, 0).conj() + matrix.get(0, 1) * matrix.get(1, 1).conj();
(topleft - 1.).powi(2) + (botright - 1.).powi(2) + 2. * off.norm_sqr()
}

#[cfg(test)]
mod test {
use super::*;

use approx::AbsDiffEq;
use ndarray::aview2;
use qiskit_circuit::operations::{Operation, Param, StandardGate, STANDARD_GATE_SIZE};

fn all_1q_gates() -> Vec<StandardGate> {
(0..STANDARD_GATE_SIZE as u8)
.filter_map(|x| {
::bytemuck::checked::try_cast::<_, StandardGate>(x)
.ok()
.filter(|gate| gate.num_qubits() == 1)
})
.collect()
}

#[test]
fn each_1q_gate_has_correct_matrix() {
let params = [0.25, -0.75, 1.25, 0.5].map(Param::Float);
let mut fails = Vec::new();
for gate in all_1q_gates() {
let params = &params[0..gate.num_params() as usize];
let direct_matrix = gate.matrix(params).unwrap();
let versor_matrix = VersorGate::from_standard(gate, params)
.unwrap()
.matrix_contiguous();
if direct_matrix.abs_diff_ne(&aview2(&versor_matrix), 1e-15) {
fails.push((gate, direct_matrix, versor_matrix));
}
}
assert_eq!(fails, [])
}

#[test]
fn can_roundtrip_1q_gate_from_matrix() {
let params = [0.25, -0.75, 1.25, 0.5].map(Param::Float);
let mut fails = Vec::new();
for gate in all_1q_gates() {
let params = &params[0..gate.num_params() as usize];
let direct_matrix = gate.matrix(params).unwrap();
let versor_matrix = VersorGate::from_ndarray(&direct_matrix.view(), 1e-15)
.unwrap()
.matrix_contiguous();
if direct_matrix.abs_diff_ne(&aview2(&versor_matrix), 1e-15) {
fails.push((gate, direct_matrix, versor_matrix));
}
}
assert_eq!(fails, [])
}

#[test]
fn pairwise_multiplication_gives_correct_matrices() {
// We have two pairs just so in the (x, x) case of iteration we're including two different
// gates to make sure that any non-commutation is accounted for.
let left_params = [0.25, -0.75, 1.25, 0.5].map(Param::Float);
let right_params = [0.5, 1.25, -0.75, 0.25].map(Param::Float);

let mut fails = Vec::new();
for (left, right) in all_1q_gates().into_iter().zip(all_1q_gates()) {
let left_params = &left_params[0..left.num_params() as usize];
let right_params = &right_params[0..right.num_params() as usize];

let direct_matrix = left
.matrix(left_params)
.unwrap()
.dot(&right.matrix(right_params).unwrap());
let versor_matrix = (VersorGate::from_standard(left, left_params).unwrap()
* VersorGate::from_standard(right, right_params).unwrap())
.matrix_contiguous();
if direct_matrix.abs_diff_ne(&aview2(&versor_matrix), 1e-15) {
fails.push((left, right, direct_matrix, versor_matrix));
}
}
assert_eq!(fails, [])
}
}
2 changes: 1 addition & 1 deletion crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ unsafe impl ::bytemuck::CheckedBitPattern for StandardGate {
type Bits = u8;

fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
*bits < 53
*bits < (STANDARD_GATE_SIZE as u8)
}
}
unsafe impl ::bytemuck::NoUninit for StandardGate {}
Expand Down

0 comments on commit df1a3c2

Please sign in to comment.