Skip to content

Commit

Permalink
Add TensorFlow support
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Sep 7, 2023
1 parent 7c3f85b commit 3de4f3f
Show file tree
Hide file tree
Showing 8 changed files with 804 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
pip install python-graphblas
pip install sparse
pip install torch
pip install tensorflow
- name: Python Test without Jupyter
if: ${{ !contains(matrix.python-version, 'pypy') }} # no scipy wheels for pypy
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ MatRepr formats matrices to HTML, string, and LaTeX, with Jupyter integration.
* **SciPy** - sparse matrices and arrays like `csr_matrix` and `coo_array`
* **NumPy** - `ndarray`
* **[PyTorch](https://pytorch.org/docs/stable/sparse.html)** - dense and sparse `torch.Tensor` [(demo)](doc/demo-pytorch.ipynb)
* **[TensorFlow](https://www.tensorflow.org/guide/sparse_tensor)** - `tf.Tensor` and `tf.SparseTensor` [(demo)](doc/demo-tensorflow.ipynb)
* **[Python-graphblas](https://github.com/python-graphblas/python-graphblas)** - `gb.Matrix` and `gb.Vector` [(demo)](doc/demo-python-graphblas.ipynb)
* **[PyData/Sparse](https://sparse.pydata.org/)** - `COO`, `DOK`, `GCXS` [(demo)](doc/demo-pydata-sparse.ipynb)
* `list`, `tuple`, including multi-dimensional and ragged
Expand Down
549 changes: 549 additions & 0 deletions doc/demo-tensorflow.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions matrepr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def _register_bundled():
from .adapters.torch_driver import PyTorchDriver
register_driver(PyTorchDriver)

from .adapters.tensorflow_driver import TensorFlowDriver
register_driver(TensorFlowDriver)

from .adapters.list_like import ListDriver
register_driver(ListDriver)

Expand Down
45 changes: 45 additions & 0 deletions matrepr/adapters/tensorflow_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

from typing import Any, Iterable, Tuple

from . import Driver


class TensorFlowDriver(Driver):
@staticmethod
def get_supported_types() -> Iterable[Tuple[str, bool]]:
return [
("tf.Tensor", True),
("tensorflow.python.framework.ops.Tensor", True),
("tensorflow.python.framework.ops.EagerTensor", True),
("tf.SparseTensor", True),
("tensorflow.python.framework.sparse_tensor.SparseTensor", True),
]

@staticmethod
def adapt(tensor: Any):
import tensorflow as tf
from .tensorflow_impl import TensorFlow1DTensorAdapter, TensorFlowFallbackAdapter
from .tensorflow_impl import TensorFlow2DTensorDenseAdapter, TensorFlow2DTensorCooAdapter
from .tensorflow_impl import TensorFlowNDTensorCooAdapter

is_sparse = isinstance(tensor, tf.SparseTensor)
ndims = len(tensor.shape)

if ndims == 0:
# single value
return TensorFlowFallbackAdapter(tensor)
elif ndims == 1:
return TensorFlow1DTensorAdapter(tensor)
elif ndims == 2:
if is_sparse:
return TensorFlow2DTensorCooAdapter(tensor)
else:
return TensorFlow2DTensorDenseAdapter(tensor)
else: # ndims > 2
if is_sparse:
return TensorFlowNDTensorCooAdapter(tensor)

return TensorFlowFallbackAdapter(tensor)
114 changes: 114 additions & 0 deletions matrepr/adapters/tensorflow_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

from typing import Any, Iterable, Union, Tuple

import numpy as np
import tensorflow as tf

from . import describe, FallbackToNative, MatrixAdapterRow, MatrixAdapterCoo, TensorAdapterCooRow


class TensorFlowBase:
def __init__(self, tensor: Union[tf.Tensor, tf.SparseTensor]):
self.tensor: Union[tf.Tensor, tf.SparseTensor] = tensor
self.is_sparse = isinstance(tensor, tf.SparseTensor)

def get_shape(self) -> tuple:
return tuple(self.tensor.shape)

def describe(self) -> str:
parts = []

nnz = len(self.tensor.values) if self.is_sparse else None

return describe(shape=tuple(self.tensor.shape),
nnz=nnz, nz_type=self.tensor.dtype,
layout="tf.SparseTensor" if self.is_sparse else "tf.Tensor",
notes=", ".join(parts))


class TensorFlowFallbackAdapter(TensorFlowBase, MatrixAdapterRow):
def __init__(self, tensor):
MatrixAdapterRow.__init__(self)
TensorFlowBase.__init__(self, tensor)

def get_shape(self) -> tuple:
return (1, )

def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]:
raise FallbackToNative


class TensorFlow1DTensorAdapter(TensorFlowBase, MatrixAdapterRow):
def __init__(self, tensor: Union[tf.Tensor, tf.SparseTensor]):
assert len(tensor.shape) <= 1
MatrixAdapterRow.__init__(self)
TensorFlowBase.__init__(self, tensor)
self.row_labels = False

def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]:
assert row_idx == 0

if not self.is_sparse:
# dense
return enumerate([v.numpy().item() for v in self.tensor[slice(*col_range)]], start=col_range[0])

indices = self.tensor.indices.numpy()
values = self.tensor.values.numpy()
mask = (col_range[0] <= indices) & (indices < col_range[1])
return [(i.item(), v.item()) for i, v in zip(indices[mask], values[mask.reshape(len(values),)])]


class TensorFlow2DTensorDenseAdapter(TensorFlowBase, MatrixAdapterRow):
def __init__(self, tensor: tf.Tensor):
assert len(tensor.shape) == 2
MatrixAdapterRow.__init__(self)
TensorFlowBase.__init__(self, tensor)

def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]:
return enumerate([v.numpy().item() for v in self.tensor[row_idx, slice(*col_range)]], start=col_range[0])


class TensorFlow2DTensorCooAdapter(TensorFlowBase, MatrixAdapterCoo):
def __init__(self, tensor: tf.SparseTensor):
assert len(tensor.shape) == 2
MatrixAdapterCoo.__init__(self)
TensorFlowBase.__init__(self, tensor)

def get_coo(self, row_range: Tuple[int, int], col_range: Tuple[int, int]) -> Iterable[Tuple[int, int, Any]]:
# rc_pairs = self.tensor.indices.numpy()
rows, cols = np.split(self.tensor.indices.numpy(), 2, axis=1)
rows = rows.flat
cols = cols.flat
values = self.tensor.values.numpy()
row_mask = (row_range[0] <= rows) & (rows < row_range[1])
col_mask = (col_range[0] <= cols) & (cols < col_range[1])
mask = row_mask & col_mask
return [(r.item(), c.item(), v.item())
for r, c, v in zip(rows[mask], cols[mask], values[mask])]


class TensorFlowNDTensorCooAdapter(TensorFlowBase, TensorAdapterCooRow):
def __init__(self, tensor: tf.SparseTensor):
self.ndim = len(tensor.shape)
self.nnz = len(tensor.values)

TensorFlowBase.__init__(self, tensor)
TensorAdapterCooRow.__init__(self)

def get_shape(self) -> tuple:
return self.nnz, self.ndim+1

def get_dense_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]:
indices = self.tensor.indices.numpy()
end = min(col_range[1], self.ndim)
ret = [indices[row_idx][i].item() for i in range(col_range[0], end)]

if col_range[1] == self.ndim + 1:
value = self.tensor.values[row_idx]
value = value.numpy().item()
ret.append(value)

return ret
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ keywords = [
"scipy",
"pydata",
"graphblas",
"torch",
"pytorch",
"tensorflow",
]

[project.urls]
Expand All @@ -40,4 +43,4 @@ repository = "https://github.com/alugowski/matrepr"

[project.optional-dependencies]
test = ["pytest", "html5lib", "scipy"]
supported = ["scipy", "numpy", "python-graphblas", "sparse", "torch"]
supported = ["scipy", "numpy", "python-graphblas", "sparse", "torch", "tensorflow"]
87 changes: 87 additions & 0 deletions tests/test_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (C) 2023 Adam Lugowski.
# Use of this source code is governed by the BSD 2-clause license found in the LICENSE.txt file.
# SPDX-License-Identifier: BSD-2-Clause

import unittest

try:
import tensorflow as tf

tf.random.set_seed(1234)
except ImportError:
tf = None

from matrepr import to_html, to_latex, to_str


def generate_fixed_value(m, n):
row_factor = 10**(1+len(str(n)))
data = []
for r in range(m):
data.append([1] * n)
for c in range(n):
data[r][c] = (r+1)*row_factor + c

return tf.constant(data, dtype=tf.int64), data


@unittest.skipIf(tf is None, "TensorFlow not installed")
class TensorFlowTests(unittest.TestCase):
def setUp(self):
rand1d = tf.random.uniform(shape=(50,)).numpy()
rand1d[rand1d < 0.6] = 0
self.rand1d = tf.convert_to_tensor(rand1d)

rand2d = tf.random.uniform(shape=(50, 30)).numpy()
rand2d[rand2d < 0.6] = 0
self.rand2d = tf.convert_to_tensor(rand2d)

rand3d = tf.random.uniform(shape=(50, 30, 10)).numpy()
rand3d[rand3d < 0.6] = 0
self.rand3d = tf.convert_to_tensor(rand3d)

self.tensors = [
(True, tf.constant(5)),
(False, tf.constant([])),
(False, tf.constant([1, 2, 3, 4])),
(False, tf.constant([[1, 2], [1003, 1004]])),
(False, tf.sparse.from_dense(tf.constant([[1, 2], [1003, 1004]]))),
(False, self.rand1d),
(False, tf.sparse.from_dense(self.rand1d)),
(False, self.rand2d),
(False, tf.sparse.from_dense(self.rand2d)),
(True, self.rand3d),
(False, tf.sparse.from_dense(self.rand3d)),
(False, tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]], values=[10, 20], dense_shape=[3, 10])),
]

def test_no_crash(self):
for fallback_ok, tensor in self.tensors:
res = to_str(tensor, title=True)
self.assertGreater(len(res), 5)

res = to_html(tensor, title=True)
self.assertGreater(len(res), 5)
if not fallback_ok:
self.assertNotIn("<pre>", res)

res = to_latex(tensor, title=True)
self.assertGreater(len(res), 5)

def test_contents_2d(self):
source_tensor, data = generate_fixed_value(8, 8)
for to_sparse in (False, True):
tensor = tf.sparse.from_dense(source_tensor) if to_sparse else source_tensor

res = to_html(tensor, notebook=False, max_rows=20, max_cols=20, title=True, indices=True)
for row in data:
for value in row:
self.assertIn(f"<td>{value}</td>", res)

trunc = to_html(tensor, notebook=False, max_rows=5, max_cols=5, title=True, indices=True)
for value in (data[0][0], data[0][-1], data[-1][0], data[-1][-1]):
self.assertIn(f"<td>{value}</td>", trunc)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3de4f3f

Please sign in to comment.