Skip to content

Commit

Permalink
split insanely complicated casting function into simpler functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Sep 28, 2024
1 parent c864a5e commit bb55d13
Showing 1 changed file with 91 additions and 113 deletions.
204 changes: 91 additions & 113 deletions pygsti/baseobjs/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import itertools as _itertools
import warnings as _warnings
from functools import lru_cache
from typing import Union, Tuple, List

import numpy as _np
import scipy.sparse as _sps
Expand Down Expand Up @@ -154,126 +155,103 @@ class Basis(_NicelySerializable):
The "vectors" of this basis, always 1D (sparse or dense) arrays.
"""

# Implementation note: casting functions are classmethods, but current implementations
# could be static methods.

@classmethod
def cast(cls, name_or_basis_or_matrices, dim=None, sparse=None, classical_name='cl'):
"""
Convert various things that can describe a basis into a `Basis` object.
def cast_from_name_and_statespace(cls, name: str, state_space: _StateSpace, sparse=None, classical_name='cl'):
tpbBases = []
if len(state_space.tensor_product_blocks_labels) == 1 \
and len(state_space.tensor_product_blocks_labels[0]) == 1:
#Special case when we can actually pipe state_space to the BuiltinBasis constructor
lbl = state_space.tensor_product_blocks_labels[0][0]
nm = name if (state_space.label_type(lbl) == 'Q') else classical_name
tpbBases.append(BuiltinBasis(nm, state_space, sparse))
else:
#TODO: add methods to StateSpace that can extract a sub-*StateSpace* object for a given label.
for tpbLabels in state_space.tensor_product_blocks_labels:
if len(tpbLabels) == 1:
nm = name if (state_space.label_type(tpbLabels[0]) == 'Q') else classical_name
tpbBases.append(BuiltinBasis(nm, state_space.label_dimension(tpbLabels[0]), sparse))
else:
tpbBases.append(TensorProdBasis([
BuiltinBasis(name if (state_space.label_type(l) == 'Q') else classical_name,
state_space.label_dimension(l), sparse) for l in tpbLabels]))
if len(tpbBases) == 1:
return tpbBases[0]
else:
return DirectSumBasis(tpbBases)

Parameters
----------
name_or_basis_or_matrices : various
Can take on a variety of values to produce different types of bases:
- `None`: an empty `ExpicitBasis`
- `Basis`: checked with `dim` and `sparse` and passed through.
- `str`: `BuiltinBasis` or `DirectSumBasis` with the given name.
- `list`: an `ExplicitBasis` if given matrices/vectors or a
`DirectSumBasis` if given a `(name, dim)` pairs.
dim : int or StateSpace, optional
The dimension of the basis to create. Sometimes this can be
inferred based on `name_or_basis_or_matrices`, other times it must
be supplied. This is the dimension of the space that this basis
fully or partially spans. This is equal to the number of basis
elements in a "full" (ordinary) basis. When a `StateSpace`
object is given, a more detailed direct-sum-of-tensor-product-blocks
structure for the state space (rather than a single dimension) is
described, and a basis is produced for this space. For instance,
a `DirectSumBasis` basis of `TensorProdBasis` components can result
when there are multiple tensor-product blocks and these blocks
consist of multiple factors.
@classmethod
def cast_from_name_and_dims(cls, name: str, dim: Union[int,list,tuple], sparse=None):
if isinstance(dim, (list, tuple)): # list/tuple of block dimensions
tpbBases = []
for tpbDim in dim:
if isinstance(tpbDim, (list, tuple)): # list/tuple of tensor-product dimensions
tpbBases.append(
TensorProdBasis([BuiltinBasis(name, factorDim, sparse) for factorDim in tpbDim]))
else:
tpbBases.append(BuiltinBasis(name, tpbDim, sparse))

sparse : bool, optional
Whether the resulting basis should be "sparse", meaning that its
elements will be sparse rather than dense matrices.
if len(tpbBases) == 1:
return tpbBases[0]
else:
return DirectSumBasis(tpbBases)
else:
return BuiltinBasis(name, dim, sparse)

@classmethod
def cast_from_basis(cls, basis, dim=None, sparse=None):
#then just check to make sure consistent with `dim` & `sparse`
if dim is not None:
if isinstance(dim, _StateSpace):
state_space = dim
if hasattr(basis, 'state_space'): # TODO - should *all* basis objects have a state_space?
assert(state_space.is_compatible_with(basis.state_space)), \
"Basis object has incompatible state space: %s != %s" % (str(state_space),
str(basis.state_space))
else: # assume dim is an integer
assert(dim == basis.dim or dim == basis.elsize), \
"Basis object has unexpected dimension: %d != %d or %d" % (dim, basis.dim, basis.elsize)
if sparse is not None:
basis = basis.with_sparsity(sparse)
return basis

classical_name : str, optional
An alternate builtin basis name that should be used when
constructing the bases for the classical sectors of `dim`,
when `dim` is a `StateSpace` object.
@classmethod
def cast_from_arrays(cls, arrays, dim=None, sparse=None):
b = ExplicitBasis(arrays, sparse=sparse)
if dim is not None:
assert(dim == b.dim), "Created explicit basis has unexpected dimension: %d vs %d" % (dim, b.dim)
if sparse is not None:
assert(sparse == b.sparse), "Basis object has unexpected sparsity: %s" % (b.sparse)
return b

Returns
-------
Basis
"""
#print("DB: CAST = ",name_or_basis_or_matrices,dim)
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
if name_or_basis_or_matrices is None: # special case of empty basis
return ExplicitBasis([], [], "*Empty*", "Empty (0-element) basis", False, sparse) # empty basis
elif isinstance(name_or_basis_or_matrices, Basis):
#then just check to make sure consistent with `dim` & `sparse`
basis = name_or_basis_or_matrices
if dim is not None:
if isinstance(dim, _StateSpace):
state_space = dim
if hasattr(basis, 'state_space'): # TODO - should *all* basis objects have a state_space?
assert(state_space.is_compatible_with(basis.state_space)), \
"Basis object has incompatible state space: %s != %s" % (str(state_space),
str(basis.state_space))
else: # assume dim is an integer
assert(dim == basis.dim or dim == basis.elsize), \
"Basis object has unexpected dimension: %d != %d or %d" % (dim, basis.dim, basis.elsize)
if sparse is not None:
basis = basis.with_sparsity(sparse)
return basis
elif isinstance(name_or_basis_or_matrices, str):
name = name_or_basis_or_matrices
@classmethod
def cast(cls, arg, dim=None, sparse=None, classical_name='cl'):
#print("DB: CAST = ",arg,dim)
if isinstance(arg, Basis):
return cls.cast_from_basis(arg, dim, sparse)
if isinstance(arg, str):
if isinstance(dim, _StateSpace):
state_space = dim
tpbBases = []
if len(state_space.tensor_product_blocks_labels) == 1 \
and len(state_space.tensor_product_blocks_labels[0]) == 1:
#Special case when we can actually pipe state_space to the BuiltinBasis constructor
lbl = state_space.tensor_product_blocks_labels[0][0]
nm = name if (state_space.label_type(lbl) == 'Q') else classical_name
tpbBases.append(BuiltinBasis(nm, state_space, sparse))
else:
#TODO: add methods to StateSpace that can extract a sub-*StateSpace* object for a given label.
for tpbLabels in state_space.tensor_product_blocks_labels:
if len(tpbLabels) == 1:
nm = name if (state_space.label_type(tpbLabels[0]) == 'Q') else classical_name
tpbBases.append(BuiltinBasis(nm, state_space.label_dimension(tpbLabels[0]), sparse))
else:
tpbBases.append(TensorProdBasis([
BuiltinBasis(name if (state_space.label_type(l) == 'Q') else classical_name,
state_space.label_dimension(l), sparse) for l in tpbLabels]))
if len(tpbBases) == 1:
return tpbBases[0]
else:
return DirectSumBasis(tpbBases)
elif isinstance(dim, (list, tuple)): # list/tuple of block dimensions
tpbBases = []
for tpbDim in dim:
if isinstance(tpbDim, (list, tuple)): # list/tuple of tensor-product dimensions
tpbBases.append(
TensorProdBasis([BuiltinBasis(name, factorDim, sparse) for factorDim in tpbDim]))
else:
tpbBases.append(BuiltinBasis(name, tpbDim, sparse))

if len(tpbBases) == 1:
return tpbBases[0]
else:
return DirectSumBasis(tpbBases)
else:
return BuiltinBasis(name, dim, sparse)
elif isinstance(name_or_basis_or_matrices, (list, tuple, _np.ndarray)):
# assume a list/array of matrices or (name, dim) pairs
if len(name_or_basis_or_matrices) == 0: # special case of empty basis
return ExplicitBasis([], [], "*Empty*", "Empty (0-element) basis", False, sparse) # empty basis
elif isinstance(name_or_basis_or_matrices[0], _np.ndarray):
b = ExplicitBasis(name_or_basis_or_matrices, sparse=sparse)
if dim is not None:
assert(dim == b.dim), "Created explicit basis has unexpected dimension: %d vs %d" % (dim, b.dim)
if sparse is not None:
assert(sparse == b.sparse), "Basis object has unexpected sparsity: %s" % (b.sparse)
return b
else: # assume els are (name, dim) pairs
compBases = [BuiltinBasis(subname, subdim, sparse)
for (subname, subdim) in name_or_basis_or_matrices]
return DirectSumBasis(compBases)
return cls.cast_from_name_and_statespace(arg, dim, sparse, classical_name)
return cls.cast_from_name_and_dims(arg, dim, sparse, classical_name)
if isinstance(arg, None) or (hasattr(arg,'__len__') and len(arg) == 0):
return ExplicitBasis([], [], "*Empty*", "Empty (0-element) basis", False, sparse)
# ^ The original implementation would return this value under two conditions.
# Either arg was None, or isinstance(arg,(tuple,list,ndarray)) and len(arg) == 0.
# We're just slightly relaxing the type requirement by using this check instead.

# At this point, original behavior would check that arg is a tuple, list, or ndarray.
# Instead, we'll just require that arg[0] is well-defined. This is enough to discern
# between the two cases we can still support.
if isinstance(arg[0], _np.ndarray):
return cls.cast_from_arrays(arg, dim, sparse)
if len(arg[0]) == 2:
compBases = [BuiltinBasis(subname, subdim, sparse) for (subname, subdim) in arg]
return DirectSumBasis(compBases)

raise ValueError("Can't cast %s to be a basis!" % str(type(arg)))

else:
raise ValueError("Can't cast %s to be a basis!" % str(type(name_or_basis_or_matrices)))

def __init__(self, name, longname, real, sparse):
super().__init__()
Expand Down

0 comments on commit bb55d13

Please sign in to comment.