Skip to content

Commit

Permalink
test that sampling-less cpp version is used when window_length has sa…
Browse files Browse the repository at this point in the history
…me sampling as input, fix it
  • Loading branch information
ianspektor committed Sep 12, 2023
1 parent 218a65d commit 196a920
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
11 changes: 6 additions & 5 deletions temporian/core/event_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,14 +2197,15 @@ def set_index(
Args:
indexes: List of index / feature names (strings) used as
the new indexes. These names should be either indexes or features in
the input.
the new indexes. These names should be either indexes or
features in the input.
Returns:
EventSet with the updated indexes.
Raises:
KeyError: If any of the specified `indexes` are not found in the input.
KeyError: If any of the specified `indexes` are not found in the
input.
"""
from temporian.core.operators.add_index import set_index

Expand Down Expand Up @@ -2252,8 +2253,8 @@ def simple_moving_average(
```
See [`EventSet.moving_count()`][temporian.EventSet.moving_count] for examples of moving window
operations with external sampling and indices.
See [`EventSet.moving_count()`][temporian.EventSet.moving_count] for
examples of moving window operations with external sampling and indices.
Args:
window_length: Sliding window's length.
Expand Down
32 changes: 32 additions & 0 deletions temporian/implementation/numpy/operators/test/moving_sum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import math
from unittest.mock import patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -270,6 +271,37 @@ def test_with_variable_winlen_different_sampling(self):

self.assertEqual(output["output"], expected_output)

@patch(
    "temporian.implementation.numpy.operators.window.moving_sum.operators_cc.moving_sum"
)
def test_with_variable_winlen_same_sampling_uses_correct_cpp_impl(
    self, cpp_moving_sum_mock
):
    """Verifies the sampling-less cpp kernel is used for a variable
    window_length that shares the input's sampling.

    The cpp `moving_sum` is replaced by a mock so that the keyword
    arguments it receives can be inspected after running the operator.
    """
    input_evset = from_pandas(
        pd.DataFrame([[1, 10.0]], columns=["timestamp", "a"])
    )
    # Built with `same_sampling_as` so the window length shares the
    # input's sampling.
    winlen_evset = from_pandas(
        pd.DataFrame([[1, 1.0]], columns=["timestamp", "length"]),
        same_sampling_as=input_evset,
    )

    operator = MovingSumOperator(
        input=input_evset.node(),
        window_length=winlen_evset.node(),
    )
    run_operator = MovingSumNumpyImplementation(operator)
    run_operator(input=input_evset, window_length=winlen_evset)

    # Both event sets keep their data under the `()` index key.
    input_data = input_evset.data[()]
    winlen_data = winlen_evset.data[()]

    # The expected call carries no `sampling_timestamps` keyword.
    cpp_moving_sum_mock.assert_called_once_with(
        evset_timestamps=input_data.timestamps,
        evset_values=input_data.features[0],
        window_length=winlen_data.features[0],
    )

def test_with_sampling_and_variable_winlen_error(self):
evset = from_pandas(
pd.DataFrame([[1, 10.0]], columns=["timestamp", "a"])
Expand Down
11 changes: 8 additions & 3 deletions temporian/implementation/numpy/operators/window/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ def __call__(
assert sampling is not None
effective_sampling = sampling

# check that sampling isn't the input's, in which case we don't pass it
# to cpp impl to use the more efficient sampling-less version
has_sampling = (
effective_sampling.node().sampling_node
is not input.node().sampling_node
)

# create destination evset
output_schema = self.operator.outputs["output"].schema
output_evset = EventSet(data={}, schema=output_schema)
Expand Down Expand Up @@ -83,9 +90,7 @@ def __call__(
effective_window_length = self.operator.window_length

sampling_timestamps = (
sampling_data.timestamps
if effective_sampling is not input
else None
sampling_data.timestamps if has_sampling else None
)

if index_key in input.data:
Expand Down

0 comments on commit 196a920

Please sign in to comment.