Commit

some suggestions
Patrick Hopf committed Aug 6, 2024
1 parent 1ce4f3c commit 4f6a5a3
Showing 3 changed files with 166 additions and 0 deletions.
112 changes: 112 additions & 0 deletions src/qutip_qoc/_rl.py
@@ -0,0 +1,112 @@
"""
This module contains ...
"""
import qutip as qt
from qutip import Qobj, QobjEvo

import numpy as np

import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env


class _RL(gym.Env): # TODO: this should be similar to your GymQubitEnv(gym.Env) implementation
"""
Class for storing a control problem and ...
"""

    def __init__(
        self,
        objective,
        time_interval,
        time_options,
        control_parameters,
        alg_kwargs,
        guess_params,
        **integrator_kwargs,
    ):
        super().__init__()  # TODO: super init your gym environment here

        # ------------------------------- copied from _GOAT class -------------------------------

        # TODO: you don't have to keep the following attributes if you don't need them;
        # this is just an inspiration for how to extract information from the input

        self._Hd = objective.H[0]
        self._Hc_lst = objective.H[1:]

        self._control_parameters = control_parameters
        self._guess_params = guess_params
        self._H = self._prepare_generator()
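        # NOTE (suggestion, not part of this commit): _prepare_generator is
        # not defined in this class yet; it was copied over from _GOAT and
        # has to be ported or replaced before this constructor runs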

        self._initial = objective.initial
        self._target = objective.target

        self._evo_time = time_interval.evo_time

        # inferred attributes
        self._norm_fac = 1 / self._target.norm()

        # integrator options
        self._integrator_kwargs = integrator_kwargs

        self._rtol = self._integrator_kwargs.get("rtol", 1e-5)
        self._atol = self._integrator_kwargs.get("atol", 1e-5)

        # choose solver and fidelity type according to problem
        if self._Hd.issuper:
            self._fid_type = alg_kwargs.get("fid_type", "TRACEDIFF")
            self._solver = qt.MESolver(H=self._H, options=self._integrator_kwargs)
        else:
            self._fid_type = alg_kwargs.get("fid_type", "PSU")
            self._solver = qt.SESolver(H=self._H, options=self._integrator_kwargs)

        self.infidelity = self._infid  # TODO: should be used to calculate the reward

        # ----------------------------------------------------------------------------------------
        # TODO: set up your gym environment as you did correctly in post10
        self.max_episode_time = time_interval.evo_time  # maximum time for an episode
        self.max_steps = time_interval.n_tslots  # maximum number of steps in an episode
        self.step_duration = time_interval.tslots[-1] / time_interval.n_tslots  # step duration for mesolve()
        ...
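        # --- suggestion (not part of this commit): a minimal sketch of the
        # remaining Gym setup; one continuous amplitude per control channel
        # and a real-valued encoding of the state are assumptions here
        # (the bounds from control_parameters could replace the +/-1 limits) ---
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(len(self._Hc_lst),), dtype=np.float32
        )
        obs_dim = 2 * int(np.prod(self._initial.shape))  # Re and Im parts
        self.observation_space = spaces.Box(
            low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32
        )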


    # ----------------------------------------------------------------------------------------

    def _infid(self, params):
        """
        Calculate infidelity to be minimized.
        """
        X = self._solver.run(
            self._initial, [0.0, self._evo_time], args={"p": params}
        ).final_state

        if self._fid_type == "TRACEDIFF":
            diff = X - self._target
            # to prevent if/else in qobj.dag() and qobj.tr()
            diff_dag = Qobj(diff.data.adjoint(), dims=diff.dims)
            g = 1 / 2 * (diff_dag * diff).data.trace()
            infid = np.real(self._norm_fac * g)
        else:
            g = self._norm_fac * self._target.overlap(X)
            if self._fid_type == "PSU":  # f_PSU (drop global phase)
                infid = 1 - np.abs(g)
            elif self._fid_type == "SU":  # f_SU (incl. global phase)
                infid = 1 - np.real(g)

        return infid

    # TODO: don't hesitate to add the required methods for your RL environment

    def step(self, action):
        ...

    def train(self):
        ...

    def result(self):
        # TODO: return qoc.Result object with the optimized pulse amplitudes
        ...
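    # ------------------------------------------------------------------
    # Suggestion (not part of this commit): one possible shape for the
    # methods above, assuming a PPO agent with an MlpPolicy and a
    # per-step reward equal to the negative infidelity of the current
    # state. The observation encoding, termination threshold, and
    # training budget are placeholders, not a final design.
    # ------------------------------------------------------------------

    def _get_obs(self):
        # encode the current quantum state as a flat real-valued vector
        arr = np.ravel(self._state.full())
        return np.concatenate([arr.real, arr.imag]).astype(np.float32)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self._state = self._initial
        self._current_step = 0
        return self._get_obs(), {}

    def step(self, action):
        # evolve one time slice with piecewise-constant amplitudes `action`
        self._state = self._solver.run(
            self._state, [0.0, self.step_duration], args={"p": action}
        ).final_state
        self._current_step += 1
        # reward: negative infidelity of the current state w.r.t. the target
        g = self._norm_fac * self._target.overlap(self._state)
        infid = 1 - np.abs(g)  # PSU-style, global phase dropped
        terminated = bool(infid < 0.01)  # placeholder for fid_err_targ
        truncated = self._current_step >= self.max_steps
        return self._get_obs(), float(-infid), terminated, truncated, {}

    def train(self):
        check_env(self)  # verify the environment follows the Gym API
        self._model = PPO("MlpPolicy", self, verbose=0)
        self._model.learn(total_timesteps=10_000)  # placeholder budget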
17 changes: 17 additions & 0 deletions src/qutip_qoc/pulse_optim.py
@@ -10,6 +10,7 @@

from qutip_qoc._optimizer import _global_local_optimization
from qutip_qoc._time import _TimeInterval
from qutip_qoc._rl import _RL

__all__ = ["optimize_pulses"]

@@ -348,6 +349,22 @@ def optimize_pulses(

        qtrl_optimizers.append(qtrl_optimizer)

    # TODO: we can deal with proper handling later
    # NOTE: the original argument list did not match the _RL.__init__
    # signature in _rl.py; it is reordered here to match. `guess_params`
    # is assumed to be collected earlier in this function (not shown here)
    if alg == "RL":
        rl_env = _RL(
            objectives[0],  # single objective assumed for now
            time_interval,
            time_options,
            control_parameters,
            algorithm_kwargs,
            guess_params,
            **integrator_kwargs,
        )
        rl_env.train()
        return rl_env.result()

    return _global_local_optimization(
        objectives,
        control_parameters,
37 changes: 37 additions & 0 deletions tests/test_result.py
@@ -152,6 +152,41 @@ def sin_z_jax(t, r, **kwargs):
    algorithm_kwargs={"alg": "CRAB", "fid_err_targ": 0.01, "fix_frequency": False},
)

# ----------------------- RL --------------------
# TODO: this is the input for the optimize_pulses() function;
# you can use this routine to test your implementation

# state to state transfer
initial = qt.basis(2, 0)
target = qt.basis(2, 1)

H_c = [qt.sigmax(), qt.sigmay(), qt.sigmaz()] # control Hamiltonians

w, d, y = 0.1, 1.0, 0.1
H_d = 1 / 2 * (w * qt.sigmaz() + d * qt.sigmax()) # drift Hamiltonian

H = [H_d] + H_c # total Hamiltonian

state2state_rl = Case(
    objectives=[Objective(initial, H, target)],
    control_parameters={"bounds": [-13, 13]},  # TODO: for now only consider bounds
    tlist=np.linspace(0, 10, 100),  # TODO: derive single step duration and max evo time / max num steps from this
    algorithm_kwargs={
        "fid_err_targ": 0.01,
        "alg": "RL",
        "max_iter": 100,
    },
)

# TODO: the unitary-evolution case should not differ much from state transfer

initial = qt.qeye(2) # Identity
target = qt.gates.hadamard_transform()

unitary_rl = state2state_rl._replace(
    objectives=[Objective(initial, H, target)],
)
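

# Suggestion (not part of this commit): a direct smoke test for the RL
# branch, mirroring how the other cases in this file call optimize_pulses().
# It assumes `optimize_pulses` is imported at the top of this file and that
# the returned qutip_qoc Result exposes an `infidelity` attribute.
def _smoke_test_rl_case(case):
    result = optimize_pulses(
        case.objectives,
        case.control_parameters,
        case.tlist,
        algorithm_kwargs=case.algorithm_kwargs,
    )
    assert result.infidelity < case.algorithm_kwargs["fid_err_targ"]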


@pytest.fixture(
    params=[
@@ -160,6 +195,8 @@
        pytest.param(state2state_param_crab, id="State to state (param. CRAB)"),
        pytest.param(state2state_goat, id="State to state (GOAT)"),
        pytest.param(state2state_jax, id="State to state (JAX)"),
        pytest.param(state2state_rl, id="State to state (RL)"),
        pytest.param(unitary_rl, id="Unitary (RL)"),
    ]
)
def tst(request):
