Skip to content

Commit

Permalink
Make column and DataPanel testing more modular (#118)
Browse files Browse the repository at this point in the history
Closes #109
The BlockManager (see #104) introduces a need for more robust DataPanel testing that exercises DataPanels with a diverse set of columns. As we add more columns, we don't want to have to update the DataPanel tests for each new column. Instead, we should specify a TestBed for each column that plugs into the DataPanel tests.

Started this for NumpyArrayColumn with #108.

Co-authored-by: Priya <[email protected]>
  • Loading branch information
seyuboglu and Priya authored Aug 9, 2021
1 parent 0122e2b commit 6ab6ba8
Show file tree
Hide file tree
Showing 27 changed files with 1,929 additions and 1,645 deletions.
4 changes: 3 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ exclude_lines =
ignore_errors = True
omit =
tests/*
meerkat/contrib/*
meerkat/contrib/*
meerkat/nn/*
setup.py
1 change: 0 additions & 1 deletion meerkat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

initialize_logging()

from meerkat import nn
from meerkat.cells.abstract import AbstractCell
from meerkat.cells.imagepath import ImagePath
from meerkat.cells.spacy import LazySpacyCell, SpacyCell
Expand Down
12 changes: 12 additions & 0 deletions meerkat/block/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,18 @@ def _repr_pandas_(self):
dfs.append(pd.DataFrame({k: self[k]._repr_pandas_() for k in cols}))
return pd.concat(objs=dfs, axis=1)

def view(self):
    """Return a new `BlockManager` holding a view of every column.

    Each column in `self._columns` contributes `col.view()` to the new
    manager under its original name; this manager is left untouched.
    """
    viewed = BlockManager()
    for column_name, column in self._columns.items():
        viewed.add_column(column.view(), column_name)
    return viewed

def copy(self):
    """Return a new `BlockManager` holding a copy of every column.

    Each column in `self._columns` contributes `col.copy()` to the new
    manager under its original name; this manager is left untouched.
    """
    copied = BlockManager()
    for column_name, column in self._columns.items():
        copied.add_column(column.copy(), column_name)
    return copied


def _serialize_block_index(index: BlockIndex) -> Union[Dict, str, int]:
if not isinstance(index, (int, str, slice)):
Expand Down
4 changes: 2 additions & 2 deletions meerkat/block/numpy_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def signature(self) -> Hashable:
dtype=self.data.dtype,
)

def _get_data(self, index: BlockIndex) -> np.ndarray:
def _get_data(self, index: BlockIndex, materialize: bool = True) -> np.ndarray:
return self.data[:, index]

@classmethod
Expand All @@ -59,7 +59,7 @@ def from_data(cls, data: np.ndarray) -> Tuple[NumpyBlock, Mapping[str, BlockInde
data = np.expand_dims(data, axis=1)
block_index = 0
elif data.shape[1] == 1:
block_index = 0
block_index = slice(0, 1)
else:
block_index = slice(0, data.shape[1])

Expand Down
3 changes: 3 additions & 0 deletions meerkat/block/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def _convert_index(index):
if isinstance(index, TensorColumn):
# need to convert to numpy for boolean indexing
return index.data.numpy()
if isinstance(index, pd.Series):
# need to convert to numpy for boolean indexing
return index.values
from meerkat.columns.pandas_column import PandasSeriesColumn

if isinstance(index, PandasSeriesColumn):
Expand Down
2 changes: 1 addition & 1 deletion meerkat/block/tensor_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def from_data(
data = torch.unsqueeze(data, dim=1)
block_index = 0
elif data.shape[1] == 1:
block_index = 0
block_index = slice(0, 1)
else:
block_index = slice(0, data.shape[1])

Expand Down
28 changes: 14 additions & 14 deletions meerkat/columns/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ def __str__(self):
def streamlit(self):
return self._repr_pandas_()

def _unpack_data(self, data):
return super(AbstractColumn, self)._unpack_data(data)

def _set_data(self, data):
if self.is_blockable():
data = self._unpack_block_view(data)
Expand Down Expand Up @@ -225,10 +222,9 @@ def __len__(self):
return self.full_length()

def full_length(self):
# Length of the underlying data stored in the column
if self._data is not None:
return len(self._data)
return 0
if self._data is None:
return 0
return len(self._data)

def _repr_pandas_(self) -> pd.Series:
raise NotImplementedError
Expand All @@ -244,7 +240,7 @@ def _repr_html_(self):
@capture_provenance()
def filter(
self,
function: Optional[Callable] = None,
function: Callable,
with_indices=False,
input_columns: Optional[Union[str, List[str]]] = None,
is_batched_fn: bool = False,
Expand All @@ -256,15 +252,11 @@ def filter(
**kwargs,
) -> Optional[AbstractColumn]:
"""Filter the elements of the column using a function."""
# Just return if the function is None
if function is None:
logger.info("`function` None, returning None.")
return None

# Return if `self` has no examples
if not len(self):
logger.info("Dataset empty, returning None.")
return None
logger.info("Dataset empty, returning it .")
return self

# Get some information about the function
function_properties = self._inspect_function(
Expand Down Expand Up @@ -304,6 +296,14 @@ def concat(columns: Sequence[AbstractColumn]) -> None:
# implement specific ones for ListColumn, NumpyColumn etc.
raise NotImplementedError

def is_equal(self, other: AbstractColumn) -> bool:
    """Test whether this column is equal to another column.

    This base implementation only declares the interface and always raises;
    subclasses are expected to override it with their own comparison.

    Args:
        other (AbstractColumn): The column to compare against.

    Returns:
        bool: Whether the two columns are equal, as defined by the
            overriding subclass.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

def batch(
self,
batch_size: int = 1,
Expand Down
7 changes: 7 additions & 0 deletions meerkat/columns/cell_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,10 @@ def concat(columns: Sequence[CellColumn]):
return columns[0].__class__.from_cells(
list(tz.concat([c.data for c in columns]))
)

def is_equal(self, other: AbstractColumn) -> bool:
    """Test whether `other` holds the same cells as this column.

    Two columns are equal when they are instances of exactly the same class,
    have the same length, and every pair of cells fetched through `lz`
    (i.e. without materializing) compares equal.

    Args:
        other (AbstractColumn): The column to compare against.

    Returns:
        bool: True if class, length and all cells match; False otherwise.
    """
    if self.__class__ != other.__class__ or len(self) != len(other):
        return False
    # A generator (rather than a list) lets `all` stop at the first
    # mismatching cell instead of comparing every remaining cell.
    return all(self.lz[idx] == other.lz[idx] for idx in range(len(self)))
44 changes: 40 additions & 4 deletions meerkat/columns/image_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import pandas as pd

from meerkat.cells.imagepath import ImagePath
from meerkat.columns.abstract import AbstractColumn
from meerkat.columns.cell_column import CellColumn
from meerkat.columns.lambda_column import LambdaColumn
from meerkat.columns.lambda_column import LambdaCell, LambdaColumn
from meerkat.columns.pandas_column import PandasSeriesColumn
from meerkat.tools.lazy_loader import LazyLoader

Expand All @@ -16,6 +17,32 @@
logger = logging.getLogger(__name__)


class ImageCell(LambdaCell):
    """Lazy cell that stores a file path and the callables used to load it.

    NOTE(review): `__init__` falls back to `self.default_loader`, which is not
    defined in this class — presumably provided elsewhere; confirm.
    """

    def __init__(
        self,
        transform: callable = None,
        loader: callable = None,
        data: str = None,
    ):
        # Only reach for the default loader when the caller supplied none.
        self.loader = loader if loader is not None else self.default_loader
        self.transform = transform
        self._data = data

    def fn(self, filepath: str):
        """Load `filepath` with `self.loader`, then apply `self.transform` if set."""
        loaded = self.loader(filepath)
        if self.transform is None:
            return loaded
        return self.transform(loaded)

    def __eq__(self, other):
        """Compare by exact class, then by `data`, `transform` and `loader`."""
        if not (other.__class__ == self.__class__):
            return False
        return (
            (self.data == other.data)
            and (self.transform == other.transform)
            and (self.loader == other.loader)
        )


class ImageColumn(LambdaColumn):
def __init__(
self,
Expand All @@ -25,12 +52,13 @@ def __init__(
*args,
**kwargs,
):
super(ImageColumn, self).__init__(
PandasSeriesColumn.from_data(data), *args, **kwargs
)
super(ImageColumn, self).__init__(PandasSeriesColumn(data), *args, **kwargs)
self.loader = self.default_loader if loader is None else loader
self.transform = transform

def _create_cell(self, data: object) -> ImageCell:
    """Build an `ImageCell` for `data` using this column's loader and transform.

    Args:
        data (object): Value stored in the cell as its `data`.

    Returns:
        ImageCell: A cell constructed with `self.loader` and `self.transform`.
    """
    return ImageCell(data=data, loader=self.loader, transform=self.transform)

def fn(self, filepath: str):
image = self.loader(filepath)
if self.transform is not None:
Expand Down Expand Up @@ -65,6 +93,14 @@ def _state_keys(cls) -> Collection:
def _repr_pandas_(self) -> pd.Series:
return "ImageCell(" + self.data.data.reset_index(drop=True) + ")"

def is_equal(self, other: AbstractColumn) -> bool:
    """Test whether `other` matches this column.

    The checks run in order and stop at the first one that fails: same exact
    class, equal `loader`, equal `transform`, and finally
    `self.data.is_equal(other.data)`.

    Args:
        other (AbstractColumn): The column to compare against.

    Returns:
        bool: Result of the chained comparison.
    """
    # NOTE: callables that do not define `__eq__` compare by identity, so two
    # columns built with distinct-but-equivalent functions are not equal.
    return (
        (other.__class__ == self.__class__)
        and (self.loader == other.loader)
        and (self.transform == other.transform)
        and self.data.is_equal(other.data)
    )


class ImageCellColumn(CellColumn):
def __init__(self, *args, **kwargs):
Expand Down
43 changes: 32 additions & 11 deletions meerkat/columns/lambda_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def get(self, *args, **kwargs):
else:
return self.fn(self.data)

def __eq__(self, other):
    """Compare two cells by exact class, then `data`, then `fn`.

    Evaluation stops at the first check that fails.
    """
    # NOTE: a function without a custom `__eq__` compares by identity, so the
    # `fn` check passes only when both cells share the same callable object.
    return (
        (other.__class__ == self.__class__)
        and (self.data == other.data)
        and (self.fn == other.fn)
    )


class LambdaColumn(AbstractColumn):
def __init__(
Expand All @@ -56,29 +63,35 @@ def __init__(
self.fn = fn
self._output_type = output_type

def __getattr__(self, name):
if not self._output_type:
raise AttributeError(name)
# TODO (Sabri): reconsider whether this is important functionality. it's not clear
# to me that this is that useful.
# def __getattr__(self, name):
# if not self._output_type:
# raise AttributeError(name)

# data = self[:2]
# if not hasattr(data, name):
# raise AttributeError(name)

data = self[:2]
if not hasattr(data, name):
raise AttributeError(name)
# data = self[:]
# return data.__getattr__(name)

data = self[:]
return data.__getattr__(name)
def _set(self, index, value):
    """Reject item assignment on this column.

    Args:
        index: Ignored.
        value: Ignored.

    Raises:
        ValueError: Always.
    """
    raise ValueError("Cannot setitem on a `LambdaColumn`.")

def fn(self, data: object):
"""Subclasses like `ImageColumn` should be able to implement their own
version."""
raise NotImplementedError

def _create_cell(self, data: object) -> LambdaCell:
    """Build the `LambdaCell` that pairs this column's `fn` with `data`.

    Args:
        data (object): Value stored in the cell as its `data`.

    Returns:
        LambdaCell: A cell constructed with `fn=self.fn`.
    """
    return LambdaCell(fn=self.fn, data=data)

def _get_cell(self, index: int, materialize: bool = True):
if materialize:
return self.fn(self._data._get(index, materialize=True))
else:
return LambdaCell(
fn=self.fn, data=self._data._get(index, materialize=False)
)
return self._create_cell(data=self._data._get(index, materialize=False))

def _get_batch(self, indices: np.ndarray, materialize: bool = True):
if materialize:
Expand Down Expand Up @@ -141,6 +154,14 @@ def _write_data(self, path):
# TODO (Sabri): avoid redundant writes in dataframes
return self.data.write(os.path.join(path, "data"))

def is_equal(self, other: AbstractColumn) -> bool:
    """Test whether `other` matches this column.

    Args:
        other (AbstractColumn): The column to compare against.

    Returns:
        bool: False if the classes differ or the `fn` attributes differ;
            otherwise the result of `self.data.is_equal(other.data)`.
    """
    # Exact class match is required; a subclass instance is not equal.
    if other.__class__ != self.__class__:
        return False
    # NOTE: functions without a custom `__eq__` compare by identity.
    if self.fn != other.fn:
        return False

    # Delegate the data comparison to the wrapped data object.
    return self.data.is_equal(other.data)

@staticmethod
def _read_data(path: str):
# TODO (Sabri): make this work for dataframes underlying the lambda column
Expand Down
3 changes: 3 additions & 0 deletions meerkat/columns/list_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,6 @@ def concat(cls, columns: Sequence[ListColumn]):
if issubclass(cls, CloneableMixin):
return columns[0]._clone(data=data)
return cls.from_list(data)

def is_equal(self, other: AbstractColumn) -> bool:
    """Test whether `other` has the same exact class and equal `data`.

    Args:
        other (AbstractColumn): The column to compare against.

    Returns:
        bool: False if the classes differ; otherwise the result of
            `self.data == other.data`.
    """
    return (self.__class__ == other.__class__) and self.data == other.data
25 changes: 16 additions & 9 deletions meerkat/columns/numpy_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from meerkat.block.abstract import BlockView
from meerkat.block.numpy_block import NumpyBlock
from meerkat.columns.abstract import AbstractColumn
from meerkat.mixins.cloneable import CloneableMixin
from meerkat.writers.concat_writer import ConcatWriter

Representer.add_representer(abc.ABCMeta, Representer.represent_name)
Expand Down Expand Up @@ -44,7 +43,7 @@ class NumpyArrayColumn(

def __init__(
self,
data: Sequence = None,
data: Sequence,
*args,
**kwargs,
):
Expand All @@ -54,25 +53,30 @@ def __init__(
"Cannot create `NumpyArrayColumn` from a `BlockView` not "
"referencing a `NumpyBlock`."
)
elif data is not None and not isinstance(data, np.memmap):
elif not isinstance(data, np.memmap):
data = np.asarray(data)
super(NumpyArrayColumn, self).__init__(data=data, *args, **kwargs)

# TODO (sabri): need to support str here
_HANDLED_TYPES = (np.ndarray, numbers.Number)

def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
def __array_ufunc__(self, ufunc: np.ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())
for x in inputs + out:
# Only support operations with instances of _HANDLED_TYPES.
# Use ArrayLike instead of type(self) for isinstance to
# allow subclasses that don't override __array_ufunc__ to
# handle ArrayLike objects.
if not isinstance(x, self._HANDLED_TYPES + (NumpyArrayColumn,)):
if not isinstance(x, self._HANDLED_TYPES + (NumpyArrayColumn,)) and not (
# support for at index
method == "at"
and isinstance(x, list)
):
return NotImplemented

# Defer to the implementation of the ufunc on unwrapped values.
inputs = tuple(x.data if isinstance(x, NumpyArrayColumn) else x for x in inputs)

if out:
kwargs["out"] = tuple(
x.data if isinstance(x, NumpyArrayColumn) else x for x in out
Expand All @@ -87,7 +91,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
return None
else:
# one return value
return type(self)(data=result)
return self._clone(data=result)

def __getattr__(self, name):
try:
Expand Down Expand Up @@ -143,9 +147,12 @@ def _read_data(
@classmethod
def concat(cls, columns: Sequence[NumpyArrayColumn]):
data = np.concatenate([c.data for c in columns])
if issubclass(cls, CloneableMixin):
return columns[0]._clone(data=data)
return cls.from_array(data)
return columns[0]._clone(data=data)

def is_equal(self, other: AbstractColumn) -> bool:
    """Test whether `other` holds the same array as this column.

    Two columns are equal when they are instances of exactly the same class
    and their underlying arrays have the same shape and the same elements.

    Args:
        other (AbstractColumn): The column to compare against.

    Returns:
        bool: True if class, shape and all elements match; False otherwise.
    """
    if other.__class__ != self.__class__:
        return False
    # `np.array_equal` checks shapes before comparing elements. Elementwise
    # `==` followed by `.all()` would instead broadcast compatible shapes
    # (so a length-1 column could equal a longer one) and fail outright on
    # incompatible shapes. It also returns a builtin bool.
    return np.array_equal(self.data, other.data)

@classmethod
def get_writer(cls, mmap: bool = False, template: AbstractColumn = None):
Expand Down
Loading

0 comments on commit 6ab6ba8

Please sign in to comment.