Skip to content

Commit

Permalink
Add some tests covering decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
philss committed Oct 8, 2024
1 parent 32c3478 commit 83afad1
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,7 @@ defmodule Explorer.Series do
def iotype(%Series{dtype: dtype}) do
case dtype do
:category -> {:u, 32}
{:decimal, _, _} -> {:s, 128}
other -> Shared.dtype_to_iotype(other)
end
end
Expand Down
1 change: 1 addition & 0 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ pub fn s_in(s: ExSeries, rhs: ExSeries) -> Result<ExSeries, ExplorerError> {
| DataType::Binary
| DataType::Date
| DataType::Time
| DataType::Decimal(_, _)
| DataType::Datetime(_, _) => is_in(&s, &rhs)?,
DataType::Categorical(Some(mapping), _) => {
let l_logical = s.categorical()?.physical();
Expand Down
49 changes: 49 additions & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,13 @@ defmodule Explorer.SeriesTest do
assert Series.equal(s, "a") |> Series.to_list() == [true, false, false, nil, true]
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 2], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})

assert s1 |> Series.equal(s2) |> Series.to_list() == [true, true, false]
end

test "performs broadcasting" do
s1 = Series.from_list([-1, 0, 1])
s2 = Series.from_list([0])
Expand Down Expand Up @@ -969,6 +976,13 @@ defmodule Explorer.SeriesTest do
assert 2 |> Series.not_equal(s1) |> Series.to_list() == [true, true, false]
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 2], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})

assert s1 |> Series.not_equal(s2) |> Series.to_list() == [false, false, true]
end

test "compare float series with a float value on the left-hand side" do
s1 = Series.from_list([1.0, 2.5, :nan, :infinity, :neg_infinity])
assert 2.5 |> Series.not_equal(s1) |> Series.to_list() == [true, false, true, true, true]
Expand Down Expand Up @@ -1077,6 +1091,13 @@ defmodule Explorer.SeriesTest do
[false, false, false, false, false]
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 2], dtype: {:decimal, nil, 2})

assert s1 |> Series.greater(s2) |> Series.to_list() == [false, false, true]
end

test "compares series of different sizes" do
s1 = Series.from_list([1, 2, 3])
s2 = Series.from_list([3, 2, 1, 4])
Expand Down Expand Up @@ -1110,6 +1131,13 @@ defmodule Explorer.SeriesTest do
assert s1 |> Series.greater_equal(s2) |> Series.to_list() == [true, false, true]
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 3, -1], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 2, 1], dtype: {:decimal, nil, 2})

assert s1 |> Series.greater_equal(s2) |> Series.to_list() == [true, true, true, false]
end

test "compare integer series with a scalar value on the right-hand side" do
s1 = Series.from_list([1, 0, 2, 3])

Expand Down Expand Up @@ -1270,6 +1298,13 @@ defmodule Explorer.SeriesTest do
[true, true, true, true, false]
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 2], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})

assert s1 |> Series.less(s2) |> Series.to_list() == [false, false, true]
end

test "raises on value mismatch" do
assert_raise ArgumentError,
"cannot invoke Explorer.Series.less/2 with mismatched dtypes: {:f, 64} and nil",
Expand Down Expand Up @@ -1297,6 +1332,13 @@ defmodule Explorer.SeriesTest do
[true, true, true, true, true, true]
end

test "compare decimal series" do
s1 = Series.from_list([1, 0, 2], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})

assert s1 |> Series.less_equal(s2) |> Series.to_list() == [true, true, true]
end

test "compare time series" do
s1 = Series.from_list([~T[00:00:00.000000], ~T[12:00:00.000000], ~T[23:59:59.999999]])
s2 = Series.from_list([~T[00:00:00.000000], ~T[12:30:00.000000], ~T[23:50:59.999999]])
Expand Down Expand Up @@ -1388,6 +1430,13 @@ defmodule Explorer.SeriesTest do
assert s1 |> Series.in(s2) |> Series.to_list() == [true, false, true]
end

test "with decimal series" do
s1 = Series.from_list([1, 2, 3], dtype: {:decimal, nil, 2})
s2 = Series.from_list([1, 0, 3], dtype: {:decimal, nil, 2})

assert s1 |> Series.in(s2) |> Series.to_list() == [true, false, true]
end

test "with signed integer series and nil on the left-hand side" do
s1 = Series.from_list([1, 2, 3, nil])
s2 = Series.from_list([1, 0, 3])
Expand Down

0 comments on commit 83afad1

Please sign in to comment.