Skip to content

Commit

Permalink
Merge pull request #64 from rochisha0/test-new
Browse files Browse the repository at this point in the history
Create tests for functions in qutip
  • Loading branch information
Ericgig authored Aug 26, 2024
2 parents 2db4223 + e7fd203 commit 0cc6256
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 0 deletions.
44 changes: 44 additions & 0 deletions tests/test_qutip/test_entropy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import bell_state
from qutip.entropy import (entropy_vn, entropy_linear, entropy_mutual, concurrence,
entropy_conditional, participation_ratio)
import qutip.settings
import qutip_jax

qutip.settings.core["auto_real_casting"] = False
qutip_jax.set_as_default()
tol = 1e-6 # Tolerance for assertion

with qutip.CoreOptions(default_dtype="jax"):
bell_state = bell_state("10")
bell_dm = bell_state * bell_state.dag()
dm = qutip.rand_dm([5, 5], distribution="pure")

@pytest.mark.parametrize("func, name, args", [
(entropy_vn, "entropy_vn", (bell_dm,)),
(entropy_linear, "entropy_linear", (bell_dm,)),
(concurrence, "concurrence", (bell_dm,)),
(participation_ratio, "participation_ratio", (bell_dm,))
])

def test_jit(func, name, args):
func_jit = jit(func)
result = func(*args)
result_jit = func_jit(*args)
assert jnp.abs(result - result_jit) < tol

@pytest.mark.parametrize("func, name, args", [
(entropy_vn, "entropy_vn", (bell_dm,)),
(entropy_linear, "entropy_linear", (bell_dm,)),
(entropy_mutual, "entropy_mutual", (dm, [0], [1])),
(concurrence, "concurrence", (bell_dm,)),
(entropy_conditional, "entropy_conditional", (bell_dm, 0)),
])
def test_grad(func, name, args):
func_grad = grad(func)
result_grad = func_grad(*args)
assert result_grad is not None


61 changes: 61 additions & 0 deletions tests/test_qutip/test_mcsolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
import jax
import jax.numpy as jnp
import qutip as qt
import qutip_jax as qjax
from qutip import mcsolve
from functools import partial

# Use JAX backend for QuTiP
qjax.set_as_default()

# Define time-dependent functions
@partial(jax.jit, static_argnames=("omega",))
def H_1_coeff(t, omega):
return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)

# Test setup for gradient calculation
def setup_system(size=2):
a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia')
sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia')

# Define the Hamiltonian
H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a

H = [H_0, [H_1_op, qt.coefficient(H_1_coeff, args={"omega": 1.0})]]

state = qt.basis(size, size - 1).to('jax') & qt.basis(2, 1).to('jax')

# Define collapse operators and observables
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]

# Time list
tlist = jnp.linspace(0.0, 1.0, 101)

return H, state, tlist, c_ops, e_ops

# Function for which we want to compute the gradient
def f(omega, H, state, tlist, c_ops, e_ops):
result = mcsolve(
H, state, tlist, c_ops, e_ops, ntraj=10,
args={"omega": omega},
options={"method": "diffrax"}
)

return result.expect[0][-1].real

# Pytest test case for gradient computation
@pytest.mark.parametrize("omega_val", [2.0])
def test_gradient_mcsolve(omega_val):
H, state, tlist, c_ops, e_ops = setup_system(size=10)

# Compute the gradient with respect to omega
grad_func = jax.grad(lambda omega: f(omega, H, state, tlist, c_ops, e_ops))
gradient = grad_func(omega_val)

# Check if the gradient is not None and has the correct shape
assert gradient is not None
assert gradient.shape == ()
assert jnp.isfinite(gradient)
43 changes: 43 additions & 0 deletions tests/test_qutip/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import basis
from qutip.core.metrics import (fidelity, tracedist, bures_dist, bures_angle,
hellinger_dist, hilbert_dist)
import qutip.settings
import qutip_jax

qutip.settings.core["auto_real_casting"] = False
qutip_jax.set_as_default()
tol = 1e-6 # Tolerance for assertion

with qutip.CoreOptions(default_dtype="jax"):
rho1 = qutip.rand_dm(dimensions=5)
rho2 = qutip.rand_dm(dimensions=5)
ket_state = basis(2, 0)
oper_state = qutip.rand_dm(2)

@pytest.mark.parametrize("func, name, args", [
(fidelity, "fidelity", (rho1, rho2)),
(tracedist, "tracedist", (rho1, rho2)),
(bures_dist, "bures_dist", (rho1, rho2)),
(bures_angle, "bures_angle", (rho1, rho2)),
(hellinger_dist, "hellinger_dist", (rho1, rho2)),
(hilbert_dist, "hilbert_dist", (rho1, rho2)),
])
def test_jit(func, name, args):
func_jit = jit(func)
result = func(*args)
result_jit = func_jit(*args)
assert jnp.abs(result - result_jit) < tol

@pytest.mark.parametrize("func, name, args", [
(fidelity, "fidelity", (ket_state, oper_state)),
(tracedist, "tracedist", (rho1, rho2)),
(hellinger_dist, "hellinger_dist", (ket_state, oper_state)),
])
def test_grad(func, name, args):
func_grad = grad(func)
result = func(*args)
result_grad = func_grad(*args)
assert result_grad is not None
96 changes: 96 additions & 0 deletions tests/test_qutip/test_qobj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest
import jax.numpy as jnp
from jax import jit, grad
from qutip import Qobj, basis, rand_dm, sigmax, identity, tensor, expect
import qutip.settings
import qutip_jax

# Set JAX backend for QuTiP
qutip.settings.core["auto_real_casting"] = False
qutip_jax.set_as_default()
tol = 1e-6 # Tolerance for assertion

# Initialize quantum objects for testing
with qutip.CoreOptions(default_dtype="jax"):
ket = basis(2, 0)
bra = ket.dag()
op1 = rand_dm(2)
identity_op = identity(2)
composite_op = tensor(op1, identity_op)


# Test case for Qobj functions with jax.jit
@pytest.mark.parametrize("func_name, func", [
("copy", lambda x: x.copy()),
("conj", lambda x: x.conj()),
("contract", lambda x: x.contract()),
("cosm", lambda x: x.cosm()),
("dag", lambda x: x.dag()),
("eigenenergies", lambda x: x.eigenenergies()),
("expm", lambda x: x.expm()),
("inv", lambda x: x.inv()),
("matrix_element", lambda x: x.matrix_element(ket, ket)),
("norm", lambda x: x.norm()),
("overlap", lambda x: x.overlap(op1)),
("ptrace", lambda x: x.ptrace([0])),
("purity", lambda x: x.purity()),
("sinm", lambda x: x.sinm()),
("sqrtm", lambda x: x.sqrtm()),
("tr", lambda x: x.tr()),
("trans", lambda x: x.trans()),
("transform", lambda x: x.transform(identity_op)),
("unit", lambda x: x.unit())
])
def test_qobj_jit(func_name, func):
# Create a jitted function using the given Qobj function
def jit_func(op):
return func(op)

# Apply jit to the function
func_jit = jit(jit_func)
result_jit = func_jit(op1)

# Check if jit result is not None
assert result_jit is not None

@pytest.mark.parametrize("func_name, func", [
("eigenenergies", lambda x: jnp.sum(x.eigenenergies())),
("overlap", lambda x: x.overlap(Qobj(jnp.eye(x.shape[0])))),
("purity", lambda x: x.purity()),
("tr", lambda x: x.tr()),
])
def test_qobj_grad_complex(func_name, func):
def grad_func(op1):
result = func(op1)
return jnp.real(result)

# Apply grad to the function
grad_func = grad(grad_func)
grad_result = grad_func(op1)

assert grad_result is not None


@pytest.mark.parametrize("func_name, func", [
("copy", lambda x: x.copy()),
("conj", lambda x: x.conj()),
("contract", lambda x: x.contract()),
("expm", lambda x: x.expm()),
("cosm", lambda x: x.cosm()),
("dag", lambda x: x.dag()),
("inv", lambda x: x.inv()),
("sinm", lambda x: x.sinm()),
("trans", lambda x: x.trans()),
])
def test_qobj_grad_differentiable(func_name, func):
def grad_func(op1):
result = func(op1)
return jnp.real(result.tr())

# Apply grad to the function
grad_func = grad(grad_func)
grad_result = grad_func(op1)

assert grad_result is not None


0 comments on commit 0cc6256

Please sign in to comment.