diff --git a/pygsti/forwardsims/torchfwdsim.py b/pygsti/forwardsims/torchfwdsim.py
index 649aa12cf..9d96c770c 100644
--- a/pygsti/forwardsims/torchfwdsim.py
+++ b/pygsti/forwardsims/torchfwdsim.py
@@ -68,7 +68,7 @@ def _build_torch_cache(model: ExplicitOpModel, layout):
         prep_label = spc.circuit_without_povm[0]
         op_labels = spc.circuit_without_povm[1:]
-        effect_labels = spc.full_effect_labels
+        povm_label = spc.povm_label
 
         rho = model.circuit_layer_operator(prep_label, typ='prep')
         """
         ^
@@ -89,7 +89,7 @@ def _build_torch_cache(model: ExplicitOpModel, layout):
         """
-        povm = model.circuit_layer_operator(spc.povm_label, 'povm')
+        povm = model.circuit_layer_operator(povm_label, 'povm')
         """
@@ -100,15 +100,16 @@ def _build_torch_cache(model: ExplicitOpModel, layout):
         """
         # Get the numerical representations
-        superket = rho.torch_base(require_grad=False, torch_handle=torch)
-        superops = [op.torch_base(require_grad=False, torch_handle=torch) for op in ops]
-        povm_mat = povm.torch_base(require_grad=False, torch_handle=torch)
-        # povm_mat = np.row_stack([effect.base for effect in effects])
+        # Right now I have a very awkward switch for gradients used in debugging.
+        require_grad = True
+        superket = rho.torch_base(require_grad, torch_handle=torch)[0]
+        superops = [op.torch_base(require_grad, torch_handle=torch)[0] for op in ops]
+        povm_mat = povm.torch_base(require_grad, torch_handle=torch)[0]
 
         label_to_state[prep_label] = superket
         for i, ol in enumerate(op_labels):
             label_to_gate[ol] = superops[i]
-        label_to_povm[''.join(effect_labels)] = povm_mat
+        label_to_povm[povm_label] = povm_mat
 
     return label_to_state, label_to_gate, label_to_povm
@@ -130,7 +131,7 @@ def _circuit_fill_probs_block(self, array_to_fill, circuit, outcomes, torch_cache
         spc = next(iter(circuit.expand_instruments_and_separate_povm(self.model, outcomes)))
         prep_label = spc.circuit_without_povm[0]
         op_labels = spc.circuit_without_povm[1:]
-        povm_label = ''.join(spc.full_effect_labels)
+        povm_label = spc.povm_label
 
         superket = l2state[prep_label]
         superops = [l2gate[ol] for ol in op_labels]
diff --git a/pygsti/modelmembers/operations/fulltpop.py b/pygsti/modelmembers/operations/fulltpop.py
index 62b5bba01..78335f9ce 100644
--- a/pygsti/modelmembers/operations/fulltpop.py
+++ b/pygsti/modelmembers/operations/fulltpop.py
@@ -167,7 +167,7 @@ def torch_base(self, require_grad: bool, torch_handle=None):
             return t, [t_param]
         else:
             t = torch_handle.from_numpy(self._rep.base)
-            return t
+            return t, []
 
     def deriv_wrt_params(self, wrt_filter=None):
         """
diff --git a/pygsti/modelmembers/povms/tppovm.py b/pygsti/modelmembers/povms/tppovm.py
index 6ee3a26c5..e8d6fdbff 100644
--- a/pygsti/modelmembers/povms/tppovm.py
+++ b/pygsti/modelmembers/povms/tppovm.py
@@ -57,12 +57,47 @@ def __reduce__(self):
         return (TPPOVM, (effects, self.evotype, self.state_space, True),
                 {'_gpindices': self._gpindices,
                  '_submember_rpindices': self._submember_rpindices})
+
+    @property
+    def dim(self):
+        effect = next(iter(self.values()))
+        return effect.dim
+
+    @property
+    def base(self):
+        effectreps = [effect._rep for effect in self.values()]
+        povm_mat = _np.row_stack([erep.state_rep.base for erep in effectreps])
+        return povm_mat
 
     def torch_base(self, require_grad=False, torch_handle=None):
         if torch_handle is None:
             import torch as torch_handle
-        assert not require_grad
-        effectreps = [effect._rep for effect in self.values()]
-        povm_mat = _np.row_stack([erep.state_rep.base for erep in effectreps])
-        povm_mat = torch_handle.from_numpy(povm_mat)
-        return povm_mat
+        if not require_grad:
+            t = torch_handle.from_numpy(self.base)
+            return t, []
+        else:
+            assert self.complement_label is not None
+            complement_index = -1
+            for i, k in enumerate(self.keys()):
+                if k == self.complement_label:
+                    complement_index = i
+                    break
+            assert complement_index >= 0
+
+            num_effects = len(self)
+            if complement_index != num_effects - 1:
+                raise NotImplementedError()
+
+            not_comp_selector = _np.ones(shape=(num_effects,), dtype=bool)
+            not_comp_selector[complement_index] = False
+            dim = self.dim
+            first_basis_vec = torch_handle.zeros(size=(1, dim), dtype=torch_handle.double)
+            first_basis_vec[0, 0] = dim ** 0.25
+
+            base = self.base
+            t_param = torch_handle.from_numpy(base[not_comp_selector, :])
+            t_param.requires_grad_(True)
+            t_func = first_basis_vec - t_param.sum(axis=0, keepdim=True)
+            t = torch_handle.row_stack((t_param, t_func))
+            return t, [t_param]
+
diff --git a/pygsti/modelmembers/states/tpstate.py b/pygsti/modelmembers/states/tpstate.py
index 0054ac705..02c230573 100644
--- a/pygsti/modelmembers/states/tpstate.py
+++ b/pygsti/modelmembers/states/tpstate.py
@@ -169,7 +169,7 @@ def torch_base(self, require_grad: bool, torch_handle=None):
             return t, t_param
         else:
             t = torch_handle.from_numpy(self._rep.base)
-            return t
+            return t, []
 
     def deriv_wrt_params(self, wrt_filter=None):
         """
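
Note on the `TPPOVM.torch_base` change above: every `torch_base` now returns a `(tensor, free_parameters)` pair, and when `require_grad=True` the TPPOVM builds its matrix so that the complement effect is a differentiable function of the remaining effects rather than an independent parameter. The standalone sketch below illustrates the same construction outside pyGSTi (names like `identity_superket` and `n_effects` are made up for this example, and the basis assumption is mine): in a normalized Pauli-product basis the identity operator's superket is `(dim ** 0.25, 0, ..., 0)`, so the complement row is that vector minus the sum of the free rows, and autograd propagates gradients through it.

```python
# Minimal sketch of the TP-constrained POVM parameterization, not pyGSTi API.
# Assumes effects are superkets in a normalized Pauli-product basis, so the
# identity superket is (sqrt(d), 0, ..., 0) with dim = d**2, i.e. dim ** 0.25.
import torch

dim = 4        # superket dimension (d**2 for a d-dimensional Hilbert space)
n_effects = 2

# Free parameters: every effect row except the complement.
t_param = torch.rand(n_effects - 1, dim, dtype=torch.double, requires_grad=True)

# Identity superket: first entry dim ** 0.25 == sqrt(d), all others zero.
identity_superket = torch.zeros(1, dim, dtype=torch.double)
identity_superket[0, 0] = dim ** 0.25

# The complement row is a *function* of the free rows, so autograd tracks it.
t_func = identity_superket - t_param.sum(dim=0, keepdim=True)
povm_mat = torch.row_stack((t_param, t_func))

# The effect rows sum to the identity superket by construction ...
assert torch.allclose(povm_mat.sum(dim=0, keepdim=True), identity_superket)

# ... and gradients flow back to t_param through both branches.
loss = (povm_mat ** 2).sum()
loss.backward()
print(t_param.grad.shape)   # torch.Size([1, 4])
```

This mirrors why `torch_base` returns `[t_param]` rather than the assembled tensor as the optimizable object: an optimizer should only step the free rows, while the complement row is recomputed from them on every forward pass.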