Commit 0bea829
forward simulation codepath that computes gradients seems to work. Haven't used it to speed up derivative computations yet.
rileyjmurray committed Feb 1, 2024
1 parent 243b757 commit 0bea829
Showing 4 changed files with 51 additions and 15 deletions.
17 changes: 9 additions & 8 deletions pygsti/forwardsims/torchfwdsim.py
@@ -68,7 +68,7 @@ def _build_torch_cache(model: ExplicitOpModel, layout):

prep_label = spc.circuit_without_povm[0]
op_labels = spc.circuit_without_povm[1:]
-effect_labels = spc.full_effect_labels
+povm_label = spc.povm_label

rho = model.circuit_layer_operator(prep_label, typ='prep')
""" ^
@@ -89,7 +89,7 @@ def _build_torch_cache(model: ExplicitOpModel, layout):
<class 'pygsti.modelmembers.operations.linearop.LinearOperator'>
<class 'pygsti.modelmembers.modelmember.ModelMember'>
"""
-povm = model.circuit_layer_operator(spc.povm_label, 'povm')
+povm = model.circuit_layer_operator(povm_label, 'povm')
"""
<class 'pygsti.modelmembers.povms.tppovm.TPPOVM'>
<class 'pygsti.modelmembers.povms.basepovm._BasePOVM'>
@@ -100,15 +100,16 @@ def _build_torch_cache(model: ExplicitOpModel, layout):
"""

# Get the numerical representations
-superket = rho.torch_base(require_grad=False, torch_handle=torch)
-superops = [op.torch_base(require_grad=False, torch_handle=torch) for op in ops]
-povm_mat = povm.torch_base(require_grad=False, torch_handle=torch)
-# povm_mat = np.row_stack([effect.base for effect in effects])
+# Right now I have a very awkward switch for gradients used in debugging.
+require_grad = True
+superket = rho.torch_base(require_grad, torch_handle=torch)[0]
+superops = [op.torch_base(require_grad, torch_handle=torch)[0] for op in ops]
+povm_mat = povm.torch_base(require_grad, torch_handle=torch)[0]

label_to_state[prep_label] = superket
for i, ol in enumerate(op_labels):
label_to_gate[ol] = superops[i]
-label_to_povm[''.join(effect_labels)] = povm_mat
+label_to_povm[povm_label] = povm_mat

return label_to_state, label_to_gate, label_to_povm
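
With the switch above, every torch_base call now returns a pair: the torch tensor for the model member plus the free parameters that autograd should differentiate (an empty list when require_grad is False; note that tpstate.py's gradient branch below returns a bare tensor rather than a list). A minimal sketch of how those pairs could eventually feed the derivative computations the commit message defers — the helper name and calling pattern are assumptions, not code from this commit:

import torch

def sketch_probs_and_grads(rho, ops, povm):
    # Hedged sketch: assumes each model member's torch_base(require_grad,
    # torch_handle) returns (tensor, free_params). tpstate.py's gradient
    # branch returns a bare tensor instead of a list, so normalize below.
    superket, p = rho.torch_base(True, torch_handle=torch)
    free_params = [p] if torch.is_tensor(p) else list(p)
    superops = []
    for op in ops:
        t, params = op.torch_base(True, torch_handle=torch)
        superops.append(t)
        free_params.extend(params)
    povm_mat, povm_params = povm.torch_base(True, torch_handle=torch)
    free_params.extend(povm_params)

    for superop in superops:           # forward-simulate the circuit
        superket = superop @ superket
    probs = povm_mat @ superket        # one probability per POVM effect

    # Gradient of each outcome probability w.r.t. every free parameter;
    # this is the "speed up derivative computations" step the commit defers.
    grads = [torch.autograd.grad(prob, free_params, retain_graph=True,
                                 allow_unused=True)
             for prob in probs]
    return probs, grads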

@@ -130,7 +131,7 @@ def _circuit_fill_probs_block(self, array_to_fill, circuit, outcomes, torch_cach
spc = next(iter(circuit.expand_instruments_and_separate_povm(self.model, outcomes)))
prep_label = spc.circuit_without_povm[0]
op_labels = spc.circuit_without_povm[1:]
-povm_label = ''.join(spc.full_effect_labels)
+povm_label = spc.povm_label

superket = l2state[prep_label]
superops = [l2gate[ol] for ol in op_labels]
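
The rest of this method is collapsed in the diff. A plausible continuation under the superket formalism, assuming superket is a 1-D torch tensor, the superops and povm_mat are 2-D, and l2povm is the third dict returned by _build_torch_cache — the actual body may differ:

# Plausible continuation (not the commit's code):
povm_mat = l2povm[povm_label]
for superop in superops:
    superket = superop @ superket      # propagate the state through each layer
probs = povm_mat @ superket            # one probability per outcome
array_to_fill[:] = probs.detach().numpy()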
2 changes: 1 addition & 1 deletion pygsti/modelmembers/operations/fulltpop.py
@@ -167,7 +167,7 @@ def torch_base(self, require_grad: bool, torch_handle=None):
return t, [t_param]
else:
t = torch_handle.from_numpy(self._rep.base)
-return t
+return t, []

def deriv_wrt_params(self, wrt_filter=None):
"""
45 changes: 40 additions & 5 deletions pygsti/modelmembers/povms/tppovm.py
@@ -57,12 +57,47 @@ def __reduce__(self):

return (TPPOVM, (effects, self.evotype, self.state_space, True),
{'_gpindices': self._gpindices, '_submember_rpindices': self._submember_rpindices})

+@property
+def dim(self):
+    effect = next(iter(self.values()))
+    return effect.dim
+
+@property
+def base(self):
+    effectreps = [effect._rep for effect in self.values()]
+    povm_mat = _np.row_stack([erep.state_rep.base for erep in effectreps])
+    return povm_mat
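
To make the new base property concrete: it stacks each effect's dense superket into one matrix, one row per effect. A hypothetical two-effect, one-qubit example (values illustrative only, not from the commit):

import numpy as np

e0 = np.array([0.70710678, 0.0, 0.0,  0.70710678])   # e.g. (I+Z)/2 as a superket
e1 = np.array([0.70710678, 0.0, 0.0, -0.70710678])   # e.g. (I-Z)/2 as a superket
povm_mat = np.row_stack([e0, e1])    # mirrors base: shape (num_effects, dim)
assert povm_mat.shape == (2, 4)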

def torch_base(self, require_grad=False, torch_handle=None):
if torch_handle is None:
import torch as torch_handle
-assert not require_grad
-effectreps = [effect._rep for effect in self.values()]
-povm_mat = _np.row_stack([erep.state_rep.base for erep in effectreps])
-povm_mat = torch_handle.from_numpy(povm_mat)
-return povm_mat
+if not require_grad:
+    t = torch_handle.from_numpy(self.base)
+    return t, []
+else:
+    assert self.complement_label is not None
+    complement_index = -1
+    for i, k in enumerate(self.keys()):
+        if k == self.complement_label:
+            complement_index = i
+            break
+    assert complement_index >= 0
+
+    num_effects = len(self)
+    if complement_index != num_effects - 1:
+        raise NotImplementedError()
+
+    not_comp_selector = _np.ones(shape=(num_effects,), dtype=bool)
+    not_comp_selector[complement_index] = False
+    dim = self.dim
+    first_basis_vec = torch_handle.zeros(size=(1, dim), dtype=torch_handle.double)
+    first_basis_vec[0, 0] = dim ** 0.25
+
+    base = self.base
+    t_param = torch_handle.from_numpy(base[not_comp_selector, :])
+    t_param.requires_grad_(True)
+    t_func = first_basis_vec - t_param.sum(axis=0, keepdim=True)
+    t = torch_handle.row_stack((t_param, t_func))
+    return t, [t_param]
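
For intuition about the gradient branch above: a TP-constrained POVM's complement effect is fixed by completeness (all effects sum to the identity), and in a normalized Pauli-product basis the identity's superket is (dim ** 0.25, 0, ..., 0), which is exactly first_basis_vec. Only the non-complement rows are free parameters; the complement row is reconstructed from them, so gradients flow through it. A self-contained sketch with illustrative one-qubit values (not the commit's code):

import numpy as np
import torch

dim = 4                                          # one qubit: superoperator dimension
free_rows = np.array([[0.70710678, 0.0, 0.0, 0.70710678]])  # non-complement effects
t_param = torch.from_numpy(free_rows)
t_param.requires_grad_(True)                     # these rows are the autograd leaves

first_basis_vec = torch.zeros(size=(1, dim), dtype=torch.double)
first_basis_vec[0, 0] = dim ** 0.25              # identity superket: (sqrt(2), 0, 0, 0)

t_func = first_basis_vec - t_param.sum(axis=0, keepdim=True)   # complement effect
t = torch.row_stack((t_param, t_func))

# Completeness holds by construction: the rows of t sum to the identity superket.
assert torch.allclose(t.sum(axis=0, keepdim=True), first_basis_vec)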

2 changes: 1 addition & 1 deletion pygsti/modelmembers/states/tpstate.py
@@ -169,7 +169,7 @@ def torch_base(self, require_grad: bool, torch_handle=None):
return t, t_param
else:
t = torch_handle.from_numpy(self._rep.base)
-return t
+return t, []

def deriv_wrt_params(self, wrt_filter=None):
"""
