-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
804 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
from typing import Any, Iterable, Tuple | ||
|
||
from . import Driver | ||
|
||
|
||
class TensorFlowDriver(Driver): | ||
@staticmethod | ||
def get_supported_types() -> Iterable[Tuple[str, bool]]: | ||
return [ | ||
("tf.Tensor", True), | ||
("tensorflow.python.framework.ops.Tensor", True), | ||
("tensorflow.python.framework.ops.EagerTensor", True), | ||
("tf.SparseTensor", True), | ||
("tensorflow.python.framework.sparse_tensor.SparseTensor", True), | ||
] | ||
|
||
@staticmethod | ||
def adapt(tensor: Any): | ||
import tensorflow as tf | ||
from .tensorflow_impl import TensorFlow1DTensorAdapter, TensorFlowFallbackAdapter | ||
from .tensorflow_impl import TensorFlow2DTensorDenseAdapter, TensorFlow2DTensorCooAdapter | ||
from .tensorflow_impl import TensorFlowNDTensorCooAdapter | ||
|
||
is_sparse = isinstance(tensor, tf.SparseTensor) | ||
ndims = len(tensor.shape) | ||
|
||
if ndims == 0: | ||
# single value | ||
return TensorFlowFallbackAdapter(tensor) | ||
elif ndims == 1: | ||
return TensorFlow1DTensorAdapter(tensor) | ||
elif ndims == 2: | ||
if is_sparse: | ||
return TensorFlow2DTensorCooAdapter(tensor) | ||
else: | ||
return TensorFlow2DTensorDenseAdapter(tensor) | ||
else: # ndims > 2 | ||
if is_sparse: | ||
return TensorFlowNDTensorCooAdapter(tensor) | ||
|
||
return TensorFlowFallbackAdapter(tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
from typing import Any, Iterable, Union, Tuple | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from . import describe, FallbackToNative, MatrixAdapterRow, MatrixAdapterCoo, TensorAdapterCooRow | ||
|
||
|
||
class TensorFlowBase: | ||
def __init__(self, tensor: Union[tf.Tensor, tf.SparseTensor]): | ||
self.tensor: Union[tf.Tensor, tf.SparseTensor] = tensor | ||
self.is_sparse = isinstance(tensor, tf.SparseTensor) | ||
|
||
def get_shape(self) -> tuple: | ||
return tuple(self.tensor.shape) | ||
|
||
def describe(self) -> str: | ||
parts = [] | ||
|
||
nnz = len(self.tensor.values) if self.is_sparse else None | ||
|
||
return describe(shape=tuple(self.tensor.shape), | ||
nnz=nnz, nz_type=self.tensor.dtype, | ||
layout="tf.SparseTensor" if self.is_sparse else "tf.Tensor", | ||
notes=", ".join(parts)) | ||
|
||
|
||
class TensorFlowFallbackAdapter(TensorFlowBase, MatrixAdapterRow): | ||
def __init__(self, tensor): | ||
MatrixAdapterRow.__init__(self) | ||
TensorFlowBase.__init__(self, tensor) | ||
|
||
def get_shape(self) -> tuple: | ||
return (1, ) | ||
|
||
def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]: | ||
raise FallbackToNative | ||
|
||
|
||
class TensorFlow1DTensorAdapter(TensorFlowBase, MatrixAdapterRow): | ||
def __init__(self, tensor: Union[tf.Tensor, tf.SparseTensor]): | ||
assert len(tensor.shape) <= 1 | ||
MatrixAdapterRow.__init__(self) | ||
TensorFlowBase.__init__(self, tensor) | ||
self.row_labels = False | ||
|
||
def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]: | ||
assert row_idx == 0 | ||
|
||
if not self.is_sparse: | ||
# dense | ||
return enumerate([v.numpy().item() for v in self.tensor[slice(*col_range)]], start=col_range[0]) | ||
|
||
indices = self.tensor.indices.numpy() | ||
values = self.tensor.values.numpy() | ||
mask = (col_range[0] <= indices) & (indices < col_range[1]) | ||
return [(i.item(), v.item()) for i, v in zip(indices[mask], values[mask.reshape(len(values),)])] | ||
|
||
|
||
class TensorFlow2DTensorDenseAdapter(TensorFlowBase, MatrixAdapterRow): | ||
def __init__(self, tensor: tf.Tensor): | ||
assert len(tensor.shape) == 2 | ||
MatrixAdapterRow.__init__(self) | ||
TensorFlowBase.__init__(self, tensor) | ||
|
||
def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]: | ||
return enumerate([v.numpy().item() for v in self.tensor[row_idx, slice(*col_range)]], start=col_range[0]) | ||
|
||
|
||
class TensorFlow2DTensorCooAdapter(TensorFlowBase, MatrixAdapterCoo): | ||
def __init__(self, tensor: tf.SparseTensor): | ||
assert len(tensor.shape) == 2 | ||
MatrixAdapterCoo.__init__(self) | ||
TensorFlowBase.__init__(self, tensor) | ||
|
||
def get_coo(self, row_range: Tuple[int, int], col_range: Tuple[int, int]) -> Iterable[Tuple[int, int, Any]]: | ||
# rc_pairs = self.tensor.indices.numpy() | ||
rows, cols = np.split(self.tensor.indices.numpy(), 2, axis=1) | ||
rows = rows.flat | ||
cols = cols.flat | ||
values = self.tensor.values.numpy() | ||
row_mask = (row_range[0] <= rows) & (rows < row_range[1]) | ||
col_mask = (col_range[0] <= cols) & (cols < col_range[1]) | ||
mask = row_mask & col_mask | ||
return [(r.item(), c.item(), v.item()) | ||
for r, c, v in zip(rows[mask], cols[mask], values[mask])] | ||
|
||
|
||
class TensorFlowNDTensorCooAdapter(TensorFlowBase, TensorAdapterCooRow): | ||
def __init__(self, tensor: tf.SparseTensor): | ||
self.ndim = len(tensor.shape) | ||
self.nnz = len(tensor.values) | ||
|
||
TensorFlowBase.__init__(self, tensor) | ||
TensorAdapterCooRow.__init__(self) | ||
|
||
def get_shape(self) -> tuple: | ||
return self.nnz, self.ndim+1 | ||
|
||
def get_dense_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]: | ||
indices = self.tensor.indices.numpy() | ||
end = min(col_range[1], self.ndim) | ||
ret = [indices[row_idx][i].item() for i in range(col_range[0], end)] | ||
|
||
if col_range[1] == self.ndim + 1: | ||
value = self.tensor.values[row_idx] | ||
value = value.numpy().item() | ||
ret.append(value) | ||
|
||
return ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (C) 2023 Adam Lugowski. | ||
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file. | ||
# SPDX-License-Identifier: BSD-2-Clause | ||
|
||
import unittest | ||
|
||
try: | ||
import tensorflow as tf | ||
|
||
tf.random.set_seed(1234) | ||
except ImportError: | ||
tf = None | ||
|
||
from matrepr import to_html, to_latex, to_str | ||
|
||
|
||
def generate_fixed_value(m, n): | ||
row_factor = 10**(1+len(str(n))) | ||
data = [] | ||
for r in range(m): | ||
data.append([1] * n) | ||
for c in range(n): | ||
data[r][c] = (r+1)*row_factor + c | ||
|
||
return tf.constant(data, dtype=tf.int64), data | ||
|
||
|
||
@unittest.skipIf(tf is None, "TensorFlow not installed") | ||
class TensorFlowTests(unittest.TestCase): | ||
def setUp(self): | ||
rand1d = tf.random.uniform(shape=(50,)).numpy() | ||
rand1d[rand1d < 0.6] = 0 | ||
self.rand1d = tf.convert_to_tensor(rand1d) | ||
|
||
rand2d = tf.random.uniform(shape=(50, 30)).numpy() | ||
rand2d[rand2d < 0.6] = 0 | ||
self.rand2d = tf.convert_to_tensor(rand2d) | ||
|
||
rand3d = tf.random.uniform(shape=(50, 30, 10)).numpy() | ||
rand3d[rand3d < 0.6] = 0 | ||
self.rand3d = tf.convert_to_tensor(rand3d) | ||
|
||
self.tensors = [ | ||
(True, tf.constant(5)), | ||
(False, tf.constant([])), | ||
(False, tf.constant([1, 2, 3, 4])), | ||
(False, tf.constant([[1, 2], [1003, 1004]])), | ||
(False, tf.sparse.from_dense(tf.constant([[1, 2], [1003, 1004]]))), | ||
(False, self.rand1d), | ||
(False, tf.sparse.from_dense(self.rand1d)), | ||
(False, self.rand2d), | ||
(False, tf.sparse.from_dense(self.rand2d)), | ||
(True, self.rand3d), | ||
(False, tf.sparse.from_dense(self.rand3d)), | ||
(False, tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]], values=[10, 20], dense_shape=[3, 10])), | ||
] | ||
|
||
def test_no_crash(self): | ||
for fallback_ok, tensor in self.tensors: | ||
res = to_str(tensor, title=True) | ||
self.assertGreater(len(res), 5) | ||
|
||
res = to_html(tensor, title=True) | ||
self.assertGreater(len(res), 5) | ||
if not fallback_ok: | ||
self.assertNotIn("<pre>", res) | ||
|
||
res = to_latex(tensor, title=True) | ||
self.assertGreater(len(res), 5) | ||
|
||
def test_contents_2d(self): | ||
source_tensor, data = generate_fixed_value(8, 8) | ||
for to_sparse in (False, True): | ||
tensor = tf.sparse.from_dense(source_tensor) if to_sparse else source_tensor | ||
|
||
res = to_html(tensor, notebook=False, max_rows=20, max_cols=20, title=True, indices=True) | ||
for row in data: | ||
for value in row: | ||
self.assertIn(f"<td>{value}</td>", res) | ||
|
||
trunc = to_html(tensor, notebook=False, max_rows=5, max_cols=5, title=True, indices=True) | ||
for value in (data[0][0], data[0][-1], data[-1][0], data[-1][-1]): | ||
self.assertIn(f"<td>{value}</td>", trunc) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |