Commit cc87684

fix(rust, python): fix logical columns of streaming multi-column sort (
ritchie46 authored Aug 2, 2023
1 parent d434aee commit cc87684
Showing 3 changed files with 126 additions and 94 deletions.
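For context: logical dtypes such as Date, Datetime, Duration and Categorical are stored on top of a plain integer "physical" representation, and the streaming multi-column sort row-encodes and later decodes columns from that physical form. The snippet below is not part of the commit; it is a small illustration of the logical/physical distinction using the public Python API.

from datetime import datetime

import polars as pl

s = pl.Series("baz", [datetime(2023, 5, 1, 15, 45)])
print(s.dtype)                # Datetime, the logical type
print(s.to_physical().dtype)  # Int64, the physical representation the sort machinery works on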
11 changes: 6 additions & 5 deletions crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs
@@ -72,7 +72,7 @@ fn finalize_dataframe(
     // those need to be inserted at the `sort_idx` position
     // in the `DataFrame`.
     if can_decode {
-        let sort_dtypes = sort_dtypes.expect("should be set");
+        let sort_dtypes = sort_dtypes.expect("should be set if 'can_decode'");
         let sort_dtypes = sort_by_idx(sort_dtypes, sort_idx);

         let encoded = encoded.binary().unwrap();
@@ -262,10 +262,11 @@ impl Sink for SortSinkMultiple {
     fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult<FinalizedSink> {
         let out = self.sort_sink.finalize(context)?;

-        let sort_dtypes = self
-            .sort_dtypes
-            .take()
-            .map(|arr| arr.iter().map(|dt| dt.to_arrow()).collect::<Vec<_>>());
+        let sort_dtypes = self.sort_dtypes.take().map(|arr| {
+            arr.iter()
+                .map(|dt| dt.to_physical().to_arrow())
+                .collect::<Vec<_>>()
+        });

         // we must adapt the finalized sink result so that the sort encoded column is dropped
         match out {
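The essence of the Rust change above is the added `to_physical()` call: the dtypes kept around for decoding the row-encoded sort keys must be the physical ones, otherwise logical columns are reconstructed incorrectly after a streaming multi-column sort. A minimal usage sketch of the behaviour being fixed (mirroring the new test added below; assumes a Polars build with streaming support):

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "foo": [3, 2, 1],
        "baz": [
            datetime(2023, 5, 1, 15, 45),
            datetime(2023, 5, 1, 13, 45),
            datetime(2023, 5, 1, 14, 45),
        ],
    }
)
out = df.lazy().sort("foo", "baz").collect(streaming=True)
assert out["foo"].to_list() == [1, 2, 3]
assert out["baz"].dtype == pl.Datetime  # the logical dtype must survive the streaming sort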
90 changes: 1 addition & 89 deletions py-polars/tests/unit/streaming/test_streaming.py
@@ -9,7 +9,7 @@

 import polars as pl
 from polars.exceptions import PolarsInefficientApplyWarning
-from polars.testing import assert_frame_equal, assert_series_equal
+from polars.testing import assert_frame_equal

 if TYPE_CHECKING:
     from pathlib import Path
@@ -241,22 +241,6 @@ def test_cross_join_stack() -> None:
     assert (t1 - t0) < 0.5


-@pytest.mark.slow()
-def test_ooc_sort(monkeypatch: Any) -> None:
-    monkeypatch.setenv("POLARS_FORCE_OOC", "1")
-
-    s = pl.arange(0, 100_000, eager=True).rename("idx")
-
-    df = s.shuffle().to_frame()
-
-    for descending in [True, False]:
-        out = (
-            df.lazy().sort("idx", descending=descending).collect(streaming=True)
-        ).to_series()
-
-        assert_series_equal(out, s.sort(descending=descending))
-
-
 def test_streaming_literal_expansion() -> None:
     df = pl.DataFrame(
         {
@@ -369,23 +353,6 @@ def test_streaming_unique(monkeypatch: Any, capfd: Any) -> None:
     assert "df -> re-project-sink -> sort_multiple" in err


-@pytest.mark.write_disk()
-def test_streaming_sort(monkeypatch: Any, capfd: Any) -> None:
-    monkeypatch.setenv("POLARS_VERBOSE", "1")
-    monkeypatch.setenv("POLARS_FORCE_OOC", "1")
-    # this creates a lot of duplicate partitions and triggers: #7568
-    assert (
-        pl.Series(np.random.randint(0, 100, 100))
-        .to_frame("s")
-        .lazy()
-        .sort("s")
-        .collect(streaming=True)["s"]
-        .is_sorted()
-    )
-    (_, err) = capfd.readouterr()
-    assert "df -> sort" in err
-
-
 @pytest.fixture(scope="module")
 def random_integers() -> pl.Series:
     np.random.seed(1)
@@ -581,61 +548,6 @@ def test_streaming_sortedness_propagation_9494() -> None:
     ).to_dict(False) == {"when": [date(2023, 5, 1), date(2023, 6, 1)], "what": [3, 3]}


-@pytest.mark.write_disk()
-def test_out_of_core_sort_9503(monkeypatch: Any) -> None:
-    monkeypatch.setenv("POLARS_FORCE_OOC", "1")
-    np.random.seed(0)
-
-    num_rows = 1_00_000
-    num_columns = 2
-    num_tables = 10
-
-    # ensure we create many chunks
-    # this will ensure we create more files
-    # and that creates contention while dumping
-    q = pl.concat(
-        [
-            pl.DataFrame(
-                [
-                    pl.Series(np.random.randint(0, 10000, size=num_rows))
-                    for _ in range(num_columns)
-                ]
-            )
-            for _ in range(num_tables)
-        ],
-        rechunk=False,
-    ).lazy()
-    q = q.sort(q.columns)
-    df = q.collect(streaming=True)
-    assert df.shape == (1_000_000, 2)
-    assert df["column_0"].flags["SORTED_ASC"]
-    assert df.head(20).to_dict(False) == {
-        "column_0": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-        "column_1": [
-            242,
-            245,
-            588,
-            618,
-            732,
-            902,
-            925,
-            945,
-            1009,
-            1161,
-            1352,
-            1365,
-            1451,
-            1581,
-            1778,
-            1836,
-            1976,
-            2091,
-            2120,
-            2124,
-        ],
-    }
-
-
 @pytest.mark.write_disk()
 @pytest.mark.slow()
 def test_streaming_generic_left_and_inner_join_from_disk(tmp_path: Path) -> None:
119 changes: 119 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_sort.py
@@ -0,0 +1,119 @@
from datetime import datetime
from typing import Any

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal


def test_streaming_sort_multiple_columns_logical_types() -> None:
    data = {
        "foo": [3, 2, 1],
        "bar": ["a", "b", "c"],
        "baz": [
            datetime(2023, 5, 1, 15, 45),
            datetime(2023, 5, 1, 13, 45),
            datetime(2023, 5, 1, 14, 45),
        ],
    }
    assert pl.DataFrame(data).lazy().sort("foo", "baz").collect(streaming=True).to_dict(
        False
    ) == {
        "foo": [1, 2, 3],
        "bar": ["c", "b", "a"],
        "baz": [
            datetime(2023, 5, 1, 14, 45),
            datetime(2023, 5, 1, 13, 45),
            datetime(2023, 5, 1, 15, 45),
        ],
    }


@pytest.mark.slow()
def test_ooc_sort(monkeypatch: Any) -> None:
    monkeypatch.setenv("POLARS_FORCE_OOC", "1")

    s = pl.arange(0, 100_000, eager=True).rename("idx")

    df = s.shuffle().to_frame()

    for descending in [True, False]:
        out = (
            df.lazy().sort("idx", descending=descending).collect(streaming=True)
        ).to_series()

        assert_series_equal(out, s.sort(descending=descending))


@pytest.mark.write_disk()
def test_streaming_sort(monkeypatch: Any, capfd: Any) -> None:
    monkeypatch.setenv("POLARS_VERBOSE", "1")
    monkeypatch.setenv("POLARS_FORCE_OOC", "1")
    # this creates a lot of duplicate partitions and triggers: #7568
    assert (
        pl.Series(np.random.randint(0, 100, 100))
        .to_frame("s")
        .lazy()
        .sort("s")
        .collect(streaming=True)["s"]
        .is_sorted()
    )
    (_, err) = capfd.readouterr()
    assert "df -> sort" in err


@pytest.mark.write_disk()
def test_out_of_core_sort_9503(monkeypatch: Any) -> None:
    monkeypatch.setenv("POLARS_FORCE_OOC", "1")
    np.random.seed(0)

    num_rows = 1_00_000
    num_columns = 2
    num_tables = 10

    # ensure we create many chunks
    # this will ensure we create more files
    # and that creates contention while dumping
    q = pl.concat(
        [
            pl.DataFrame(
                [
                    pl.Series(np.random.randint(0, 10000, size=num_rows))
                    for _ in range(num_columns)
                ]
            )
            for _ in range(num_tables)
        ],
        rechunk=False,
    ).lazy()
    q = q.sort(q.columns)
    df = q.collect(streaming=True)
    assert df.shape == (1_000_000, 2)
    assert df["column_0"].flags["SORTED_ASC"]
    assert df.head(20).to_dict(False) == {
        "column_0": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        "column_1": [
            242,
            245,
            588,
            618,
            732,
            902,
            925,
            945,
            1009,
            1161,
            1352,
            1365,
            1451,
            1581,
            1778,
            1836,
            1976,
            2091,
            2120,
            2124,
        ],
    }
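The out-of-core tests above rely on the POLARS_FORCE_OOC and POLARS_VERBOSE environment variables (set via monkeypatch) to force the spill-to-disk sort path and to print the streaming plan. A minimal sketch of doing the same outside pytest; it only restates what the tests do and is not part of the commit:

import os

os.environ["POLARS_FORCE_OOC"] = "1"  # force the out-of-core sort path even for small data
os.environ["POLARS_VERBOSE"] = "1"    # log the streaming physical plan to stderr

import polars as pl

df = pl.Series("idx", range(100_000)).shuffle().to_frame()
out = df.lazy().sort("idx").collect(streaming=True)
assert out["idx"].is_sorted()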
