diff --git a/tests/test_qutip/test_entropy.py b/tests/test_qutip/test_entropy.py new file mode 100644 index 0000000..3ecb9a5 --- /dev/null +++ b/tests/test_qutip/test_entropy.py @@ -0,0 +1,44 @@ +import pytest +import jax.numpy as jnp +from jax import jit, grad +from qutip import bell_state +from qutip.entropy import (entropy_vn, entropy_linear, entropy_mutual, concurrence, + entropy_conditional, participation_ratio) +import qutip.settings +import qutip_jax + +qutip.settings.core["auto_real_casting"] = False +qutip_jax.set_as_default() +tol = 1e-6 # Tolerance for assertion + +with qutip.CoreOptions(default_dtype="jax"): + bell_state = bell_state("10") + bell_dm = bell_state * bell_state.dag() + dm = qutip.rand_dm([5, 5], distribution="pure") + +@pytest.mark.parametrize("func, name, args", [ + (entropy_vn, "entropy_vn", (bell_dm,)), + (entropy_linear, "entropy_linear", (bell_dm,)), + (concurrence, "concurrence", (bell_dm,)), + (participation_ratio, "participation_ratio", (bell_dm,)) +]) + +def test_jit(func, name, args): + func_jit = jit(func) + result = func(*args) + result_jit = func_jit(*args) + assert jnp.abs(result - result_jit) < tol + +@pytest.mark.parametrize("func, name, args", [ + (entropy_vn, "entropy_vn", (bell_dm,)), + (entropy_linear, "entropy_linear", (bell_dm,)), + (entropy_mutual, "entropy_mutual", (dm, [0], [1])), + (concurrence, "concurrence", (bell_dm,)), + (entropy_conditional, "entropy_conditional", (bell_dm, 0)), +]) +def test_grad(func, name, args): + func_grad = grad(func) + result_grad = func_grad(*args) + assert result_grad is not None + + diff --git a/tests/test_qutip/test_mcsolve.py b/tests/test_qutip/test_mcsolve.py new file mode 100644 index 0000000..ac7c095 --- /dev/null +++ b/tests/test_qutip/test_mcsolve.py @@ -0,0 +1,61 @@ +import pytest +import jax +import jax.numpy as jnp +import qutip as qt +import qutip_jax as qjax +from qutip import mcsolve +from functools import partial + +# Use JAX backend for QuTiP +qjax.set_as_default() + +# Define time-dependent functions +@partial(jax.jit, static_argnames=("omega",)) +def H_1_coeff(t, omega): + return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t) + +# Test setup for gradient calculation +def setup_system(size=2): + a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia') + sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia') + + # Define the Hamiltonian + H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm + H_1_op = sm * a.dag() + sm.dag() * a + + H = [H_0, [H_1_op, qt.coefficient(H_1_coeff, args={"omega": 1.0})]] + + state = qt.basis(size, size - 1).to('jax') & qt.basis(2, 1).to('jax') + + # Define collapse operators and observables + c_ops = [jnp.sqrt(0.1) * a] + e_ops = [a.dag() * a, sm.dag() * sm] + + # Time list + tlist = jnp.linspace(0.0, 1.0, 101) + + return H, state, tlist, c_ops, e_ops + +# Function for which we want to compute the gradient +def f(omega, H, state, tlist, c_ops, e_ops): + result = mcsolve( + H, state, tlist, c_ops, e_ops, ntraj=10, + args={"omega": omega}, + options={"method": "diffrax"} + ) + + return result.expect[0][-1].real + +# Pytest test case for gradient computation +@pytest.mark.parametrize("omega_val", [2.0]) +def test_gradient_mcsolve(omega_val): + H, state, tlist, c_ops, e_ops = setup_system(size=10) + + # Compute the gradient with respect to omega + grad_func = jax.grad(lambda omega: f(omega, H, state, tlist, c_ops, e_ops)) + gradient = grad_func(omega_val) + + # Check if the gradient is not None and has the correct shape + assert gradient is not None + assert gradient.shape == () + assert jnp.isfinite(gradient) diff --git a/tests/test_qutip/test_metrics.py b/tests/test_qutip/test_metrics.py new file mode 100644 index 0000000..a462183 --- /dev/null +++ b/tests/test_qutip/test_metrics.py @@ -0,0 +1,43 @@ +import pytest +import jax.numpy as jnp +from jax import jit, grad +from qutip import basis +from qutip.core.metrics import (fidelity, tracedist, bures_dist, bures_angle, + hellinger_dist, hilbert_dist) +import qutip.settings +import qutip_jax + +qutip.settings.core["auto_real_casting"] = False +qutip_jax.set_as_default() +tol = 1e-6 # Tolerance for assertion + +with qutip.CoreOptions(default_dtype="jax"): + rho1 = qutip.rand_dm(dimensions=5) + rho2 = qutip.rand_dm(dimensions=5) + ket_state = basis(2, 0) + oper_state = qutip.rand_dm(2) + +@pytest.mark.parametrize("func, name, args", [ + (fidelity, "fidelity", (rho1, rho2)), + (tracedist, "tracedist", (rho1, rho2)), + (bures_dist, "bures_dist", (rho1, rho2)), + (bures_angle, "bures_angle", (rho1, rho2)), + (hellinger_dist, "hellinger_dist", (rho1, rho2)), + (hilbert_dist, "hilbert_dist", (rho1, rho2)), +]) +def test_jit(func, name, args): + func_jit = jit(func) + result = func(*args) + result_jit = func_jit(*args) + assert jnp.abs(result - result_jit) < tol + +@pytest.mark.parametrize("func, name, args", [ + (fidelity, "fidelity", (ket_state, oper_state)), + (tracedist, "tracedist", (rho1, rho2)), + (hellinger_dist, "hellinger_dist", (ket_state, oper_state)), +]) +def test_grad(func, name, args): + func_grad = grad(func) + result = func(*args) + result_grad = func_grad(*args) + assert result_grad is not None diff --git a/tests/test_qutip/test_qobj.py b/tests/test_qutip/test_qobj.py new file mode 100644 index 0000000..059ae9d --- /dev/null +++ b/tests/test_qutip/test_qobj.py @@ -0,0 +1,96 @@ +import pytest +import jax.numpy as jnp +from jax import jit, grad +from qutip import Qobj, basis, rand_dm, sigmax, identity, tensor, expect +import qutip.settings +import qutip_jax + +# Set JAX backend for QuTiP +qutip.settings.core["auto_real_casting"] = False +qutip_jax.set_as_default() +tol = 1e-6 # Tolerance for assertion + +# Initialize quantum objects for testing +with qutip.CoreOptions(default_dtype="jax"): + ket = basis(2, 0) + bra = ket.dag() + op1 = rand_dm(2) + identity_op = identity(2) + composite_op = tensor(op1, identity_op) + + +# Test case for Qobj functions with jax.jit +@pytest.mark.parametrize("func_name, func", [ + ("copy", lambda x: x.copy()), + ("conj", lambda x: x.conj()), + ("contract", lambda x: x.contract()), + ("cosm", lambda x: x.cosm()), + ("dag", lambda x: x.dag()), + ("eigenenergies", lambda x: x.eigenenergies()), + ("expm", lambda x: x.expm()), + ("inv", lambda x: x.inv()), + ("matrix_element", lambda x: x.matrix_element(ket, ket)), + ("norm", lambda x: x.norm()), + ("overlap", lambda x: x.overlap(op1)), + ("ptrace", lambda x: x.ptrace([0])), + ("purity", lambda x: x.purity()), + ("sinm", lambda x: x.sinm()), + ("sqrtm", lambda x: x.sqrtm()), + ("tr", lambda x: x.tr()), + ("trans", lambda x: x.trans()), + ("transform", lambda x: x.transform(identity_op)), + ("unit", lambda x: x.unit()) +]) +def test_qobj_jit(func_name, func): + # Create a jitted function using the given Qobj function + def jit_func(op): + return func(op) + + # Apply jit to the function + func_jit = jit(jit_func) + result_jit = func_jit(op1) + + # Check if jit result is not None + assert result_jit is not None + +@pytest.mark.parametrize("func_name, func", [ + ("eigenenergies", lambda x: jnp.sum(x.eigenenergies())), + ("overlap", lambda x: x.overlap(Qobj(jnp.eye(x.shape[0])))), + ("purity", lambda x: x.purity()), + ("tr", lambda x: x.tr()), +]) +def test_qobj_grad_complex(func_name, func): + def grad_func(op1): + result = func(op1) + return jnp.real(result) + + # Apply grad to the function + grad_func = grad(grad_func) + grad_result = grad_func(op1) + + assert grad_result is not None + + +@pytest.mark.parametrize("func_name, func", [ + ("copy", lambda x: x.copy()), + ("conj", lambda x: x.conj()), + ("contract", lambda x: x.contract()), + ("expm", lambda x: x.expm()), + ("cosm", lambda x: x.cosm()), + ("dag", lambda x: x.dag()), + ("inv", lambda x: x.inv()), + ("sinm", lambda x: x.sinm()), + ("trans", lambda x: x.trans()), +]) +def test_qobj_grad_differentiable(func_name, func): + def grad_func(op1): + result = func(op1) + return jnp.real(result.tr()) + + # Apply grad to the function + grad_func = grad(grad_func) + grad_result = grad_func(op1) + + assert grad_result is not None + +