Skip to content

Commit

Permalink
Merge pull request #20 from tenpy/refactor_zero_data
Browse files Browse the repository at this point in the history
refactor zero_data
  • Loading branch information
Jakob-Unfried authored Dec 5, 2024
2 parents beafd2c + ffe6fe1 commit 36546c5
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 39 deletions.
17 changes: 13 additions & 4 deletions cyten/backends/abelian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,10 +1853,19 @@ def truncate_singular_values(self, S: DiagonalTensor, chi_max: int | None, chi_m
mask_data, small_leg = self.mask_from_block(keep, large_leg=S.leg)
return mask_data, small_leg, err, new_norm

def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str
) -> AbelianBackendData:
block_inds = np.zeros((0, codomain.num_spaces + domain.num_spaces), dtype=int)
return AbelianBackendData(dtype, device, blocks=[], block_inds=block_inds, is_sorted=True)
def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str,
all_blocks: bool = False) -> AbelianBackendData:
if not all_blocks:
block_inds = np.zeros((0, codomain.num_spaces + domain.num_spaces), dtype=int)
return AbelianBackendData(dtype, device, blocks=[], block_inds=block_inds, is_sorted=True)

block_inds = _valid_block_inds(codomain=codomain, domain=domain)
zero_blocks = []
for idcs in block_inds:
shape = [leg.multiplicities[i]
for i, leg in zip(idcs, conventional_leg_order(codomain, domain))]
zero_blocks.append(self.block_backend.zero_block(shape, dtype=dtype))
return AbelianBackendData(dtype, device, zero_blocks, block_inds, is_sorted=True)

def zero_diagonal_data(self, co_domain: ProductSpace, dtype: Dtype, device: str
) -> DiagonalData:
Expand Down
9 changes: 6 additions & 3 deletions cyten/backends/abstract_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,9 +765,12 @@ def _truncate_singular_values_selection(self, S: np.ndarray, qdims: np.ndarray |
return mask, err, new_norm

@abstractmethod
def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str
) -> Data:
"""Data for a zero tensor"""
def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str,
all_blocks: bool = False) -> Data:
"""Data for a zero tensor. Explicitly constructs the zero blocks corresponding to all
consistent sectors of the correct shape if `all_blocks == True`. Otherwise, the blocks
in the returned `Data` is an empty list.
"""
...

@abstractmethod
Expand Down
54 changes: 25 additions & 29 deletions cyten/backends/fusion_tree_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,29 +240,6 @@ def discard_zero_blocks(self, backend: BlockBackend, eps: float) -> None:
self.blocks = [self.blocks[i] for i in keep]
self.block_inds = self.block_inds[keep]

@classmethod
def _zero_data(cls, codomain: ProductSpace, domain: ProductSpace, backend: BlockBackend,
dtype: Dtype, device: str) -> FusionTreeData:
"""Return `FusionTreeData` consistent with `codomain` and `domain`, where all blocks
(corresponding to all allowed coupled sectors) are zero. These zero blocks are stored
such that new values can be assigned step by step. Note however that zero blocks are
generally not stored in tensors.
"""
# TODO leave this here? would be more consistet in FusionTreeBackend
block_shapes = []
block_inds = []
for j, coupled in enumerate(domain.sectors):
i = codomain.sectors_where(coupled)
if i == None:
continue
shp = (block_size(codomain, coupled), block_size(domain, coupled))
block_shapes.append(shp)
block_inds.append([i, j])
block_inds = np.array(block_inds)

zero_blocks = [backend.zero_block(block_shape, dtype=dtype) for block_shape in block_shapes]
return cls(block_inds, zero_blocks, dtype=dtype, device=device, is_sorted=True)


class FusionTreeBackend(TensorBackend):
"""`ProductSpace`s on the individual legs of the tensors are not supported, only
Expand Down Expand Up @@ -1278,11 +1255,30 @@ def truncate_singular_values(self, S: DiagonalTensor, chi_max: int | None, chi_m
is_dual=S.leg.is_bra_space)
return mask_data, small_leg, err, new_norm

def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str,
all_blocks: bool = False) -> FusionTreeData:
if not all_blocks:
return FusionTreeData(block_inds=np.zeros((0, 2), int), blocks=[], dtype=dtype,
device=device)

block_shapes = []
block_inds = []
for j, coupled in enumerate(domain.sectors):
i = codomain.sectors_where(coupled)
if i == None:
continue
shp = (block_size(codomain, coupled), block_size(domain, coupled))
block_shapes.append(shp)
block_inds.append([i, j])

def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str
) -> FusionTreeData:
return FusionTreeData(block_inds=np.zeros((0, 2), int), blocks=[], dtype=dtype,
device=device)
if len(block_inds) == 0:
return FusionTreeData(block_inds=np.zeros((0, 2), int), blocks=[], dtype=dtype,
device=device)

block_inds = np.array(block_inds)
zero_blocks = [self.block_backend.zero_block(block_shape, dtype=dtype)
for block_shape in block_shapes]
return FusionTreeData(block_inds, zero_blocks, dtype=dtype, device=device, is_sorted=True)

def zero_diagonal_data(self, co_domain: ProductSpace, dtype: Dtype, device: str
) -> DiagonalData:
Expand Down Expand Up @@ -1784,8 +1780,8 @@ def _apply_two_trees_in_keys(self, ten: SymmetricTensor, new_codomain: ProductSp
new_domain: ProductSpace, block_axes_permutation: list[int],
) -> FusionTreeData:
backend = ten.backend.block_backend
new_data = FusionTreeData._zero_data(new_codomain, new_domain, backend, Dtype.complex128,
device=ten.data.device)
new_data = ten.backend.zero_data(new_codomain, new_domain, Dtype.complex128,
device=ten.data.device, all_blocks=True)

for alpha_tree, beta_tree, tree_block in _tree_block_iter(ten):
contributions = self[(alpha_tree, beta_tree)]
Expand Down
3 changes: 2 additions & 1 deletion cyten/backends/no_symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ def truncate_singular_values(self, S: DiagonalTensor, chi_max: int | None, chi_m
)
return mask_data, new_leg, err, new_norm

def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str):
def zero_data(self, codomain: ProductSpace, domain: ProductSpace, dtype: Dtype, device: str,
all_blocks: bool = False) -> Data:
return self.block_backend.zero_block(
shape=[l.dim for l in conventional_leg_order(codomain, domain)],
dtype=dtype, device=device
Expand Down
4 changes: 2 additions & 2 deletions tests/pytest/linalg/test_backend_nonabelian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,8 +1892,8 @@ def cross_check_single_b_symbol(ten: SymmetricTensor, bend_up: bool
new_codomain = [new_space1, new_space2][bend_up]
new_domain = [new_space1, new_space2][not bend_up]

new_data = ftb.FusionTreeData._zero_data(new_codomain, new_domain, block_backend,
dtype=Dtype.complex128, device=device)
new_data = ten.backend.zero_data(new_codomain, new_domain, dtype=Dtype.complex128,
device=device, all_blocks=True)

for alpha_tree, beta_tree, tree_block in ftb._tree_block_iter(ten):
modified_shape = [ten.codomain[i].sector_multiplicity(sec)
Expand Down

0 comments on commit 36546c5

Please sign in to comment.