From bb55d13205acc3b6a4bf4b7ccbbc793fd71a57ea Mon Sep 17 00:00:00 2001 From: Riley Murray Date: Sat, 28 Sep 2024 09:59:50 -0400 Subject: [PATCH] split insanely complicated casting function into simpler functions --- pygsti/baseobjs/basis.py | 204 +++++++++++++++++---------------------- 1 file changed, 91 insertions(+), 113 deletions(-) diff --git a/pygsti/baseobjs/basis.py b/pygsti/baseobjs/basis.py index f4b329e7a..438f7f6db 100644 --- a/pygsti/baseobjs/basis.py +++ b/pygsti/baseobjs/basis.py @@ -14,6 +14,7 @@ import itertools as _itertools import warnings as _warnings from functools import lru_cache +from typing import Union, Tuple, List import numpy as _np import scipy.sparse as _sps @@ -154,126 +155,103 @@ class Basis(_NicelySerializable): The "vectors" of this basis, always 1D (sparse or dense) arrays. """ + # Implementation note: casting functions are classmethods, but current implementations + # could be static methods. + @classmethod - def cast(cls, name_or_basis_or_matrices, dim=None, sparse=None, classical_name='cl'): - """ - Convert various things that can describe a basis into a `Basis` object. + def cast_from_name_and_statespace(cls, name: str, state_space: _StateSpace, sparse=None, classical_name='cl'): + tpbBases = [] + if len(state_space.tensor_product_blocks_labels) == 1 \ + and len(state_space.tensor_product_blocks_labels[0]) == 1: + #Special case when we can actually pipe state_space to the BuiltinBasis constructor + lbl = state_space.tensor_product_blocks_labels[0][0] + nm = name if (state_space.label_type(lbl) == 'Q') else classical_name + tpbBases.append(BuiltinBasis(nm, state_space, sparse)) + else: + #TODO: add methods to StateSpace that can extract a sub-*StateSpace* object for a given label. + for tpbLabels in state_space.tensor_product_blocks_labels: + if len(tpbLabels) == 1: + nm = name if (state_space.label_type(tpbLabels[0]) == 'Q') else classical_name + tpbBases.append(BuiltinBasis(nm, state_space.label_dimension(tpbLabels[0]), sparse)) + else: + tpbBases.append(TensorProdBasis([ + BuiltinBasis(name if (state_space.label_type(l) == 'Q') else classical_name, + state_space.label_dimension(l), sparse) for l in tpbLabels])) + if len(tpbBases) == 1: + return tpbBases[0] + else: + return DirectSumBasis(tpbBases) - Parameters - ---------- - name_or_basis_or_matrices : various - Can take on a variety of values to produce different types of bases: - - - `None`: an empty `ExpicitBasis` - - `Basis`: checked with `dim` and `sparse` and passed through. - - `str`: `BuiltinBasis` or `DirectSumBasis` with the given name. - - `list`: an `ExplicitBasis` if given matrices/vectors or a - `DirectSumBasis` if given a `(name, dim)` pairs. - - dim : int or StateSpace, optional - The dimension of the basis to create. Sometimes this can be - inferred based on `name_or_basis_or_matrices`, other times it must - be supplied. This is the dimension of the space that this basis - fully or partially spans. This is equal to the number of basis - elements in a "full" (ordinary) basis. When a `StateSpace` - object is given, a more detailed direct-sum-of-tensor-product-blocks - structure for the state space (rather than a single dimension) is - described, and a basis is produced for this space. For instance, - a `DirectSumBasis` basis of `TensorProdBasis` components can result - when there are multiple tensor-product blocks and these blocks - consist of multiple factors. + @classmethod + def cast_from_name_and_dims(cls, name: str, dim: Union[int,list,tuple], sparse=None): + if isinstance(dim, (list, tuple)): # list/tuple of block dimensions + tpbBases = [] + for tpbDim in dim: + if isinstance(tpbDim, (list, tuple)): # list/tuple of tensor-product dimensions + tpbBases.append( + TensorProdBasis([BuiltinBasis(name, factorDim, sparse) for factorDim in tpbDim])) + else: + tpbBases.append(BuiltinBasis(name, tpbDim, sparse)) - sparse : bool, optional - Whether the resulting basis should be "sparse", meaning that its - elements will be sparse rather than dense matrices. + if len(tpbBases) == 1: + return tpbBases[0] + else: + return DirectSumBasis(tpbBases) + else: + return BuiltinBasis(name, dim, sparse) + + @classmethod + def cast_from_basis(cls, basis, dim=None, sparse=None): + #then just check to make sure consistent with `dim` & `sparse` + if dim is not None: + if isinstance(dim, _StateSpace): + state_space = dim + if hasattr(basis, 'state_space'): # TODO - should *all* basis objects have a state_space? + assert(state_space.is_compatible_with(basis.state_space)), \ + "Basis object has incompatible state space: %s != %s" % (str(state_space), + str(basis.state_space)) + else: # assume dim is an integer + assert(dim == basis.dim or dim == basis.elsize), \ + "Basis object has unexpected dimension: %d != %d or %d" % (dim, basis.dim, basis.elsize) + if sparse is not None: + basis = basis.with_sparsity(sparse) + return basis - classical_name : str, optional - An alternate builtin basis name that should be used when - constructing the bases for the classical sectors of `dim`, - when `dim` is a `StateSpace` object. + @classmethod + def cast_from_arrays(cls, arrays, dim=None, sparse=None): + b = ExplicitBasis(arrays, sparse=sparse) + if dim is not None: + assert(dim == b.dim), "Created explicit basis has unexpected dimension: %d vs %d" % (dim, b.dim) + if sparse is not None: + assert(sparse == b.sparse), "Basis object has unexpected sparsity: %s" % (b.sparse) + return b - Returns - ------- - Basis - """ - #print("DB: CAST = ",name_or_basis_or_matrices,dim) - from pygsti.baseobjs.statespace import StateSpace as _StateSpace - if name_or_basis_or_matrices is None: # special case of empty basis - return ExplicitBasis([], [], "*Empty*", "Empty (0-element) basis", False, sparse) # empty basis - elif isinstance(name_or_basis_or_matrices, Basis): - #then just check to make sure consistent with `dim` & `sparse` - basis = name_or_basis_or_matrices - if dim is not None: - if isinstance(dim, _StateSpace): - state_space = dim - if hasattr(basis, 'state_space'): # TODO - should *all* basis objects have a state_space? - assert(state_space.is_compatible_with(basis.state_space)), \ - "Basis object has incompatible state space: %s != %s" % (str(state_space), - str(basis.state_space)) - else: # assume dim is an integer - assert(dim == basis.dim or dim == basis.elsize), \ - "Basis object has unexpected dimension: %d != %d or %d" % (dim, basis.dim, basis.elsize) - if sparse is not None: - basis = basis.with_sparsity(sparse) - return basis - elif isinstance(name_or_basis_or_matrices, str): - name = name_or_basis_or_matrices + @classmethod + def cast(cls, arg, dim=None, sparse=None, classical_name='cl'): + #print("DB: CAST = ",arg,dim) + if isinstance(arg, Basis): + return cls.cast_from_basis(arg, dim, sparse) + if isinstance(arg, str): if isinstance(dim, _StateSpace): - state_space = dim - tpbBases = [] - if len(state_space.tensor_product_blocks_labels) == 1 \ - and len(state_space.tensor_product_blocks_labels[0]) == 1: - #Special case when we can actually pipe state_space to the BuiltinBasis constructor - lbl = state_space.tensor_product_blocks_labels[0][0] - nm = name if (state_space.label_type(lbl) == 'Q') else classical_name - tpbBases.append(BuiltinBasis(nm, state_space, sparse)) - else: - #TODO: add methods to StateSpace that can extract a sub-*StateSpace* object for a given label. - for tpbLabels in state_space.tensor_product_blocks_labels: - if len(tpbLabels) == 1: - nm = name if (state_space.label_type(tpbLabels[0]) == 'Q') else classical_name - tpbBases.append(BuiltinBasis(nm, state_space.label_dimension(tpbLabels[0]), sparse)) - else: - tpbBases.append(TensorProdBasis([ - BuiltinBasis(name if (state_space.label_type(l) == 'Q') else classical_name, - state_space.label_dimension(l), sparse) for l in tpbLabels])) - if len(tpbBases) == 1: - return tpbBases[0] - else: - return DirectSumBasis(tpbBases) - elif isinstance(dim, (list, tuple)): # list/tuple of block dimensions - tpbBases = [] - for tpbDim in dim: - if isinstance(tpbDim, (list, tuple)): # list/tuple of tensor-product dimensions - tpbBases.append( - TensorProdBasis([BuiltinBasis(name, factorDim, sparse) for factorDim in tpbDim])) - else: - tpbBases.append(BuiltinBasis(name, tpbDim, sparse)) - - if len(tpbBases) == 1: - return tpbBases[0] - else: - return DirectSumBasis(tpbBases) - else: - return BuiltinBasis(name, dim, sparse) - elif isinstance(name_or_basis_or_matrices, (list, tuple, _np.ndarray)): - # assume a list/array of matrices or (name, dim) pairs - if len(name_or_basis_or_matrices) == 0: # special case of empty basis - return ExplicitBasis([], [], "*Empty*", "Empty (0-element) basis", False, sparse) # empty basis - elif isinstance(name_or_basis_or_matrices[0], _np.ndarray): - b = ExplicitBasis(name_or_basis_or_matrices, sparse=sparse) - if dim is not None: - assert(dim == b.dim), "Created explicit basis has unexpected dimension: %d vs %d" % (dim, b.dim) - if sparse is not None: - assert(sparse == b.sparse), "Basis object has unexpected sparsity: %s" % (b.sparse) - return b - else: # assume els are (name, dim) pairs - compBases = [BuiltinBasis(subname, subdim, sparse) - for (subname, subdim) in name_or_basis_or_matrices] - return DirectSumBasis(compBases) + return cls.cast_from_name_and_statespace(arg, dim, sparse, classical_name) + return cls.cast_from_name_and_dims(arg, dim, sparse, classical_name) + if isinstance(arg, None) or (hasattr(arg,'__len__') and len(arg) == 0): + return ExplicitBasis([], [], "*Empty*", "Empty (0-element) basis", False, sparse) + # ^ The original implementation would return this value under two conditions. + # Either arg was None, or isinstance(arg,(tuple,list,ndarray)) and len(arg) == 0. + # We're just slightly relaxing the type requirement by using this check instead. + + # At this point, original behavior would check that arg is a tuple, list, or ndarray. + # Instead, we'll just require that arg[0] is well-defined. This is enough to discern + # between the two cases we can still support. + if isinstance(arg[0], _np.ndarray): + return cls.cast_from_arrays(arg, dim, sparse) + if len(arg[0]) == 2: + compBases = [BuiltinBasis(subname, subdim, sparse) for (subname, subdim) in arg] + return DirectSumBasis(compBases) + + raise ValueError("Can't cast %s to be a basis!" % str(type(arg))) - else: - raise ValueError("Can't cast %s to be a basis!" % str(type(name_or_basis_or_matrices))) def __init__(self, name, longname, real, sparse): super().__init__()