Skip to content

Commit

Permalink
Fix test warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Sep 7, 2023
1 parent 8aed0e9 commit 7e2c728
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 16 deletions.
7 changes: 6 additions & 1 deletion matrepr/adapters/sparse_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: BSD-2-Clause

from typing import Any, Iterable, Tuple
import warnings

import sparse
from sparse import COO
Expand Down Expand Up @@ -64,7 +65,11 @@ def __init__(self, mat):
PyDataSparseBase.__init__(self, mat)

def get_coo(self, row_range: Tuple[int, int], col_range: Tuple[int, int]) -> Iterable[Tuple[int, int, Any]]:
ret = COO(self.mat[slice(*row_range), slice(*col_range)])
with warnings.catch_warnings():
# COO will complain about a structure it itself created
warnings.simplefilter("ignore", category=DeprecationWarning, lineno=261)

ret = COO(self.mat[slice(*row_range), slice(*col_range)])
ret.coords[0] += row_range[0]
ret.coords[1] += col_range[0]
return zip(ret.coords[0], ret.coords[1], ret.data)
Expand Down
33 changes: 19 additions & 14 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: BSD-2-Clause

import unittest
import warnings

try:
import sparse
Expand All @@ -11,12 +12,12 @@

from matrepr import to_html, to_latex, to_str

import numpy.random
numpy.random.seed(123)
import scipy
import numpy as np
np.random.seed(123)


def generate_fixed_value(m, n):
import scipy
row_factor = 10**(1+len(str(n)))
nnz = m*n
rows, cols, data = [1] * nnz, [1] * nnz, [1] * nnz
Expand All @@ -34,16 +35,20 @@ def generate_fixed_value(m, n):
class PyDataSparseTests(unittest.TestCase):
def setUp(self):
self.mats = [
sparse.COO([], shape=(0,)),
sparse.COO(coords=[1, 4], data=[11, 44], shape=(10,)),
sparse.COO([], shape=(0, 0)),
sparse.COO([], shape=(10, 10)),
sparse.COO(np.array([1])),
sparse.COO(coords=np.array([1, 4]), data=np.array([11, 44]), shape=(10,)),
sparse.COO(np.empty(shape=(10, 10))),
sparse.random((10, 10), density=0.4),
sparse.COO.from_scipy_sparse(generate_fixed_value(10, 10)),
sparse.COO(coords=[[0, 0], [0, 0]], data=[111, 222], shape=(13, 13)),
sparse.COO(coords=[[0, 1], [3, 2], [1, 3]], data=[111, 222], shape=(5, 5, 5)),
sparse.COO(coords=np.array([[0, 0], [0, 0]]), data=np.array([111, 222]), shape=(13, 13)), # has dupes
sparse.COO(coords=np.array([[0, 1], [3, 2], [1, 3]]), data=np.array([111, 222]), shape=(5, 5, 5)),
]

with warnings.catch_warnings():
# COO will incorrectly complain that the object is not ndarray when it is.
warnings.simplefilter("ignore", category=DeprecationWarning, lineno=261)
self.mats.append(sparse.COO(np.empty(shape=(0, 0))))

self.types = [
sparse.COO,
sparse.DOK,
Expand Down Expand Up @@ -86,14 +91,14 @@ def test_formats(self):

def test_contents_1d(self):
values = [1000, 1001, 1002, 1003, 1004]
vec = sparse.COO([0, 1, 2, 3, 4], data=values, shape=(10,))
vec = sparse.COO(np.array([0, 1, 2, 3, 4]), data=np.array(values), shape=(10,))
res = to_html(vec, notebook=False, max_rows=20, max_cols=20, title=True, indices=True)
for value in values:
self.assertIn(f"<td>{value}</td>", res)

def test_truncate_1d(self):
values = [1000, 1001, 1002, 1003, 1009]
vec = sparse.COO([0, 1, 2, 3, 9], data=values, shape=(10,))
vec = sparse.COO(np.array([0, 1, 2, 3, 9]), data=np.array(values), shape=(10,))
res = to_html(vec, notebook=False, max_rows=3, max_cols=3, num_after_dots=1, title=True, indices=True)
for value in [1000, 1009]:
self.assertIn(f"<td>{value}</td>", res)
Expand Down Expand Up @@ -123,7 +128,7 @@ def test_truncate_2d(self):

def test_contents_3d(self):
values = [111, 222]
mat = sparse.COO(coords=[[0, 1], [3, 2], [1, 3]], data=values, shape=(5, 5, 5))
mat = sparse.COO(coords=np.array([[0, 1], [3, 2], [1, 3]]), data=np.array(values), shape=(5, 5, 5))
res = to_html(mat, notebook=False, max_rows=20, max_cols=20, title=True, indices=True)
res_str = to_str(mat)
for value in values:
Expand All @@ -137,7 +142,7 @@ def test_contents_3d(self):

def test_truncate_3d(self):
values = [111, 222]
mat = sparse.COO(coords=[[0, 10], [30, 2], [1, 30]], data=values, shape=(50, 50, 50))
mat = sparse.COO(coords=np.array([[0, 10], [30, 2], [1, 30]]), data=np.array(values), shape=(50, 50, 50))

res = to_html(mat, notebook=False, max_rows=30, max_cols=3, num_after_dots=1)
res_str = to_str(mat, max_rows=30, max_cols=3, num_after_dots=1)
Expand All @@ -164,7 +169,7 @@ def test_no_comma_space(self):
self.assertNotIn(" ,", res_str)

def test_patch_sparse(self):
source_mat = sparse.COO(coords=[1, 4, 6], data=[11, 44, 222], shape=(10,))
source_mat = sparse.COO(coords=np.array([1, 4, 6]), data=np.array([11, 44, 222]), shape=(10,))

# noinspection PyUnresolvedReferences
import matrepr.patch.sparse
Expand Down
2 changes: 1 addition & 1 deletion tests/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_tabulate_forward(self):
mat = [[1, 2], [2000, 300000]]
left = to_str(mat, colalign=["left", "left"])
right = to_str(mat, colalign=["right", "right"])
self.assertNotEquals(left, right)
self.assertNotEqual(left, right)


if __name__ == '__main__':
Expand Down
3 changes: 3 additions & 0 deletions tests/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
# SPDX-License-Identifier: BSD-2-Clause

import unittest
import warnings

try:
# Suppress warning from inside tensorflow
warnings.filterwarnings("ignore", message="module 'sre_constants' is deprecated")
import tensorflow as tf

tf.random.set_seed(1234)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: BSD-2-Clause

import unittest
import warnings

import numpy as np

Expand Down Expand Up @@ -30,6 +31,9 @@ def generate_fixed_value(m, n):
@unittest.skipIf(torch is None, "PyTorch not installed")
class PyTorchTests(unittest.TestCase):
def setUp(self):
# filter beta state warning
warnings.filterwarnings("ignore", message="Sparse CSR tensor support is in beta state")

rand2d = torch.rand(50, 30)
self.rand2d = rand2d[rand2d < 0.6] = 0

Expand Down

0 comments on commit 7e2c728

Please sign in to comment.