diff --git a/.github/scripts/test_bytecode_parser.py b/.github/scripts/test_bytecode_parser.py index 50073c1b0035c..2abe9bf5de138 100644 --- a/.github/scripts/test_bytecode_parser.py +++ b/.github/scripts/test_bytecode_parser.py @@ -12,13 +12,14 @@ Running it without `PYTHONPATH` set will result in the test failing. """ + import datetime as dt # noqa: F401 import subprocess from datetime import datetime # noqa: F401 from typing import Any, Callable import pytest -from polars.utils.udfs import BytecodeParser +from polars._utils.udfs import BytecodeParser from tests.unit.operations.map.test_inefficient_map_warning import ( MY_DICT, NOOP_TEST_CASES, @@ -44,7 +45,7 @@ def test_bytecode_parser_expression_in_ipython( col: str, func: Callable[[Any], Any], expected: str ) -> None: script = ( - "from polars.utils.udfs import BytecodeParser; " + "from polars._utils.udfs import BytecodeParser; " "import datetime as dt; " "from datetime import datetime; " "import numpy as np; " @@ -73,7 +74,7 @@ def test_bytecode_parser_expression_noop(func: str) -> None: ) def test_bytecode_parser_expression_noop_in_ipython(func: str) -> None: script = ( - "from polars.utils.udfs import BytecodeParser; " + "from polars._utils.udfs import BytecodeParser; " f"MY_DICT = {MY_DICT};" f'parser = BytecodeParser({func}, map_target="expr");' f'print(not parser.can_attempt_rewrite() or not parser.to_expression("x"));' diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 2253e05861e48..e616c3f3f53db 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -87,9 +87,7 @@ jobs: env: RUSTFLAGS: -C embed-bitcode -D warnings working-directory: py-polars - run: | - source activate - maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native + run: maturin develop --release -- -C codegen-units=8 -C lto=thin -C target-cpu=native - name: Run H2O AI database benchmark - on strings working-directory: py-polars/tests/benchmark diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml new file mode 100644 index 0000000000000..3e93c7c0e3dc3 --- /dev/null +++ b/.github/workflows/codecov.yml @@ -0,0 +1,105 @@ +name: Code coverage + +on: + pull_request: + paths: + - '**.rs' + - '**.py' + - .github/workflows/codecov.yml + push: + branches: + - main + paths: + - '**.rs' + - '**.py' + - .github/workflows/codecov.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + working-directory: py-polars + shell: bash + +jobs: + coverage: + name: Code Coverage + runs-on: macos-latest + env: + RUSTFLAGS: '-C instrument-coverage --cfg=coverage --cfg=coverage_nightly --cfg=trybuild_no_target' + RUST_BACKTRACE: 1 + LLVM_PROFILE_FILE: '/Users/runner/work/polars/polars/target/polars-%p-%3m.profraw' + CARGO_LLVM_COV: 1 + CARGO_LLVM_COV_SHOW_ENV: 1 + CARGO_LLVM_COV_TARGET_DIR: '/Users/runner/work/polars/polars/target' + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Create virtual environment + run: | + python -m venv .venv + echo "$GITHUB_WORKSPACE/py-polars/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: pip install -r requirements-dev.txt + + - name: Set up Rust + run: rustup component add llvm-tools-preview + + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + + - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref_name == 'main' }} + + - name: Prepare coverage + run: cargo llvm-cov clean --workspace + + - name: Run tests + run: > + cargo test --all-features + -p polars-arrow + -p polars-compute + -p polars-core + -p polars-io + -p polars-lazy + -p polars-ops + -p polars-plan + -p polars-row + -p polars-sql + -p polars-time + -p polars-utils + + - name: Run Rust integration tests + run: cargo test --all-features -p polars --test it + + - name: Install Polars + run: maturin develop + + - name: Run Python tests + run: pytest --cov -n auto --dist loadgroup -m "not benchmark and not docs" --cov-report xml:main.xml + continue-on-error: true + + - name: Run Python tests - async reader + env: + POLARS_FORCE_ASYNC: 1 + run: pytest --cov -m "not benchmark and not docs" tests/unit/io/ --cov-report xml:async.xml + continue-on-error: true + + - name: Report coverage + run: cargo llvm-cov report --lcov --output-path coverage.lcov + + - name: Upload coverage information + uses: codecov/codecov-action@v4 + with: + files: py-polars/coverage.lcov,py-polars/main.xml,py-polars/async.xml + name: macos + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml index f24a39d2be727..cddd3fcbced31 100644 --- a/.github/workflows/docs-global.yml +++ b/.github/workflows/docs-global.yml @@ -82,9 +82,7 @@ jobs: - name: Install Polars working-directory: py-polars - run: | - source activate - maturin develop + run: maturin develop - name: Set up Graphviz uses: ts-graphviz/setup-graphviz@v2 diff --git a/.github/workflows/lint-global.yml b/.github/workflows/lint-global.yml index d3383dc164fcc..c01c96ad01119 100644 --- a/.github/workflows/lint-global.yml +++ b/.github/workflows/lint-global.yml @@ -15,4 +15,4 @@ jobs: - name: Lint Markdown and TOML uses: dprint/check@v2.2 - name: Spell Check with Typos - uses: crate-ci/typos@v1.17.2 + uses: crate-ci/typos@v1.18.2 diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index ac1a272746eb9..b4ccfb7748b3f 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -129,13 +129,14 @@ jobs: env: IS_LTS_CPU: ${{ matrix.package == 'polars-lts-cpu' }} IS_MACOS: ${{ matrix.os == 'macos-latest' }} + # IMPORTANT: All features enabled here should also be included in py-polars/polars/_cpu_check.py run: | if [[ "$IS_LTS_CPU" = true ]]; then FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt elif [[ "$IS_MACOS" = true ]]; then - FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma + FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma,+pclmulqdq else - FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt + FEATURES=+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+avx2,+fma,+bmi1,+bmi2,+lzcnt,+pclmulqdq fi echo "features=$FEATURES" >> $GITHUB_OUTPUT @@ -178,6 +179,13 @@ jobs: --out dist manylinux: ${{ matrix.architecture == 'aarch64' && '2_24' || 'auto' }} + - name: Test wheel + # Only test on x86-64 for now as this matches the runner architecture + if: matrix.architecture == 'x86-64' + run: | + pip install --force-reinstall --verbose dist/*.whl + python -c 'import polars' + - name: Upload wheel uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 55540a3917ecc..f6ac631177c82 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -73,9 +73,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Install Polars - run: | - source activate - maturin develop + run: maturin develop - name: Run doctests if: github.ref_name != 'main' && matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest' @@ -94,7 +92,9 @@ jobs: - name: Run tests async reader tests if: github.ref_name != 'main' && matrix.os != 'windows-latest' - run: POLARS_FORCE_ASYNC=1 pytest -m "not benchmark and not docs" tests/unit/io/ + env: + POLARS_FORCE_ASYNC: 1 + run: pytest -m "not benchmark and not docs" tests/unit/io/ - name: Check import without optional dependencies if: github.ref_name != 'main' && matrix.python-version == '3.12' && matrix.os == 'ubuntu-latest' diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml index 875b3c3fba12c..80c6af2733e6b 100644 --- a/.github/workflows/test-rust.yml +++ b/.github/workflows/test-rust.yml @@ -46,6 +46,7 @@ jobs: run: > cargo test --all-features --no-run -p polars-arrow + -p polars-compute -p polars-core -p polars-io -p polars-lazy @@ -61,6 +62,7 @@ jobs: run: > cargo test --all-features -p polars-arrow + -p polars-compute -p polars-core -p polars-io -p polars-lazy diff --git a/Cargo.lock b/Cargo.lock index 888ed5074cd5c..94cf4ec37d21f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -25,9 +25,9 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", @@ -90,21 +90,46 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2faccea4cc4ab4a667ce676a30e8ec13922a692c99bb8f5b11f1502c72e04220" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" + +[[package]] +name = "apache-avro" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "ceb7c683b2f8f40970b70e39ff8be514c95b96fcb9c4af87e1ed2cb2e10801a0" +dependencies = [ + "crc32fast", + "digest", + "lazy_static", + "libflate 2.0.0", + "log", + "num-bigint", + "quad-rand", + "rand", + "regex-lite", + "serde", + "serde_json", + "snap", + "strum", + "strum_macros", + "thiserror", + "typed-builder", + "uuid", +] [[package]] name = "argminmax" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "202108b46429b765ef483f8a24d5c46f48c14acfdacc086dd4ab6dddf6bcdbd2" +checksum = "52424b59d69d69d5056d508b260553afd91c57e21849579cd1f50ee8b8b88eaa" dependencies = [ "num-traits", ] @@ -199,7 +224,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -210,7 +235,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -244,7 +269,7 @@ dependencies = [ "crc", "fallible-streaming-iterator", "futures", - "libflate", + "libflate 1.4.0", "serde", "serde_json", "snap", @@ -252,9 +277,9 @@ dependencies = [ [[package]] name = "aws-config" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b30c39ebe61f75d1b3785362b1586b41991873c9ab3e317a9181c246fb71d82" +checksum = "0b96342ea8948ab9bef3e6234ea97fc32e2d8a88d8fb6a084e52267317f94b6b" dependencies = [ "aws-credential-types", "aws-runtime", @@ -271,7 +296,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 0.2.11", + "http 0.2.12", "hyper", "ring", "time", @@ -282,9 +307,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7" +checksum = "273fa47dafc9ef14c2c074ddddbea4561ff01b7f68d5091c0e9737ced605c01d" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -294,9 +319,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa" +checksum = "6e38bab716c8bf07da24be07ecc02e0f5656ce8f30a891322ecdcb202f943b85" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -308,7 +333,7 @@ dependencies = [ "aws-types", "bytes", "fastrand", - "http 0.2.11", + "http 0.2.12", "http-body", "percent-encoding", "pin-project-lite", @@ -318,9 +343,9 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.14.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "951f7730f51a2155c711c85c79f337fbc02a577fa99d2a0a8059acfce5392113" +checksum = "93d35d39379445970fc3e4ddf7559fff2c32935ce0b279f9cb27080d6b7c6d94" dependencies = [ "aws-credential-types", "aws-runtime", @@ -336,7 +361,7 @@ dependencies = [ "aws-smithy-xml", "aws-types", "bytes", - "http 0.2.11", + "http 0.2.12", "http-body", "once_cell", "percent-encoding", @@ -347,9 +372,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.12.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f486420a66caad72635bc2ce0ff6581646e0d32df02aa39dc983bfe794955a5b" +checksum = "d84bd3925a17c9adbf6ec65d52104a44a09629d8f70290542beeee69a95aee7f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -361,7 +386,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "http 0.2.11", + "http 0.2.12", "once_cell", "regex-lite", "tracing", @@ -369,9 +394,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.12.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ddccf01d82fce9b4a15c8ae8608211ee7db8ed13a70b514bbfe41df3d24841" +checksum = "2c2dae39e997f58bc4d6292e6244b26ba630c01ab671b6f9f44309de3eb80ab8" dependencies = [ "aws-credential-types", "aws-runtime", @@ -383,7 +408,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "http 0.2.11", + "http 0.2.12", "once_cell", "regex-lite", "tracing", @@ -391,9 +416,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.12.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a591f8c7e6a621a501b2b5d2e88e1697fcb6274264523a6ad4d5959889a41ce" +checksum = "17fd9a53869fee17cea77e352084e1aa71e2c5e323d974c13a9c2bcfd9544c7f" dependencies = [ "aws-credential-types", "aws-runtime", @@ -406,7 +431,7 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", - "http 0.2.11", + "http 0.2.12", "once_cell", "regex-lite", "tracing", @@ -414,9 +439,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742" +checksum = "8ada00a4645d7d89f296fe0ddbc3fe3554f03035937c849a05d37ddffc1f29a1" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -428,8 +453,8 @@ dependencies = [ "form_urlencoded", "hex", "hmac", - "http 0.2.11", - "http 1.0.0", + "http 0.2.12", + "http 1.1.0", "once_cell", "p256", "percent-encoding", @@ -443,9 +468,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6" +checksum = "fcf7f09a27286d84315dfb9346208abb3b0973a692454ae6d0bc8d803fcce3b4" dependencies = [ "futures-util", "pin-project-lite", @@ -454,9 +479,9 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.60.4" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be2acd1b9c6ae5859999250ed5a62423aedc5cf69045b844432de15fa2f31f2b" +checksum = "0fd4b66f2a8e7c84d7e97bda2666273d41d2a2e25302605bcf906b7b2661ae5e" dependencies = [ "aws-smithy-http", "aws-smithy-types", @@ -464,7 +489,7 @@ dependencies = [ "crc32c", "crc32fast", "hex", - "http 0.2.11", + "http 0.2.12", "http-body", "md-5", "pin-project-lite", @@ -486,9 +511,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.4" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d" +checksum = "b6ca214a6a26f1b7ebd63aa8d4f5e2194095643023f9608edf99a58247b9d80d" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -496,7 +521,7 @@ dependencies = [ "bytes", "bytes-utils", "futures-core", - "http 0.2.11", + "http 0.2.12", "http-body", "once_cell", "percent-encoding", @@ -507,18 +532,18 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.4" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e" +checksum = "1af80ecf3057fb25fe38d1687e94c4601a7817c6a1e87c1b0635f7ecb644ace5" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.4" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9" +checksum = "eb27084f72ea5fc20033efe180618677ff4a2f474b53d84695cfe310a6526cbc" dependencies = [ "aws-smithy-types", "urlencoding", @@ -526,9 +551,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea" +checksum = "fbb5fca54a532a36ff927fbd7407a7c8eb9c3b4faf72792ba2965ea2cad8ed55" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -537,7 +562,7 @@ dependencies = [ "bytes", "fastrand", "h2", - "http 0.2.11", + "http 0.2.12", "http-body", "hyper", "hyper-rustls", @@ -551,14 +576,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29" +checksum = "22389cb6f7cac64f266fb9f137745a9349ced7b47e0d2ba503e9e40ede4f7060" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", - "http 0.2.11", + "http 0.2.12", + "http 1.1.0", "pin-project-lite", "tokio", "tracing", @@ -567,15 +593,15 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3" +checksum = "f081da5481210523d44ffd83d9f0740320050054006c719eae0232d411f024d3" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", - "http 0.2.11", + "http 0.2.12", "http-body", "itoa", "num-integer", @@ -590,24 +616,24 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.4" +version = "0.60.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218" +checksum = "0fccd8f595d0ca839f9f2548e66b99514a85f92feb4c01cf2868d93eb4888a42" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.1.4" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4" +checksum = "d07c63521aa1ea9a9f92a701f1a08ce3fd20b46c6efc0d5c8947c1fd879e3df1" dependencies = [ "aws-credential-types", "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "http 0.2.11", + "http 0.2.12", "rustc_version", "tracing", ] @@ -722,15 +748,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.14.0" +version = "3.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] name = "bytemuck" -version = "1.14.1" +version = "1.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2490600f404f2b94c167e31d3ed1d5f3c225a0f3b80230053b3e0b7b962bd9" +checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" dependencies = [ "bytemuck_derive", ] @@ -743,7 +769,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -791,9 +817,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.83" +version = "1.0.90" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" dependencies = [ "jobserver", "libc", @@ -807,22 +833,22 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.33" +version = "0.4.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f13690e35a5e4ace198e7beea2895d29f3a9cc55015fcebe6336bd2010af9eb" +checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] name = "chrono-tz" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91d7b79e99bfaa0d47da0687c43aa3b7381938a62ad3a6498599039321f660b7" +checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" dependencies = [ "chrono", "chrono-tz-build", @@ -869,18 +895,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.18" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +checksum = "b230ab84b0ffdf890d5a10abdbc8b83ae1c4918275daea1ab8801f71536b2651" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.4.18" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" dependencies = [ "anstyle", "clap_lex", @@ -888,9 +914,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" [[package]] name = "cmake" @@ -921,9 +947,9 @@ checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" [[package]] name = "const-random" -version = "0.1.17" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaf16c9c2c612020bcfd042e170f6e32de9b9d75adb5277cdbbd2e2c8c8299a" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" dependencies = [ "const-random-macro", ] @@ -955,6 +981,15 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "core2" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b49ba7ef1ad6107f8824dbe97de947cbaac53c44e7f9756a1fba0d37c1eec505" +dependencies = [ + "memchr", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -981,18 +1016,18 @@ checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" [[package]] name = "crc32c" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8f48d60e5b4d2c53d5c2b1d8a58c849a70ae5e5509b08a48d047e3b65714a74" +checksum = "89254598aa9b9fa608de44b3ae54c810f0f06d755e24c50177f1f8f31ff50ce2" dependencies = [ "rustc_version", ] [[package]] name = "crc32fast" -version = "1.3.2" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" dependencies = [ "cfg-if", ] @@ -1035,9 +1070,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" dependencies = [ "crossbeam-utils", ] @@ -1136,6 +1171,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "dary_heap" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" + [[package]] name = "der" version = "0.6.1" @@ -1174,9 +1215,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "dyn-clone" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "545b22097d44f8a9581187cdf93de7a71e4722bf51200cfaba810865b49a495d" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" [[package]] name = "ecdsa" @@ -1192,9 +1233,9 @@ dependencies = [ [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "elliptic-curve" @@ -1234,7 +1275,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -1394,7 +1435,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -1458,9 +1499,9 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "git2" -version = "0.18.1" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf97ba92db08df386e10c8ede66a2a0369bd277090afd8710e19e38de9ec0cd" +checksum = "1b3ba52851e73b46a4c3df1d89343741112003f0f6f13beb0dfac9e457c3fdcd" dependencies = [ "bitflags 2.4.2", "libc", @@ -1497,7 +1538,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http 0.2.11", + "http 0.2.12", "indexmap", "slab", "tokio", @@ -1507,9 +1548,9 @@ dependencies = [ [[package]] name = "half" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +checksum = "b5eceaaeec696539ddaf7b333340f1af35a5aa87ae3e4f3ead0532f72affab2e" dependencies = [ "cfg-if", "crunchy", @@ -1518,11 +1559,11 @@ dependencies = [ [[package]] name = "halfbrown" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5681137554ddff44396e5f149892c769d45301dd9aa19c51602a89ee214cb0ec" +checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" dependencies = [ - "hashbrown 0.13.2", + "hashbrown 0.14.3", "serde", ] @@ -1560,9 +1601,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.4" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hex" @@ -1590,9 +1631,9 @@ dependencies = [ [[package]] name = "http" -version = "0.2.11" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" dependencies = [ "bytes", "fnv", @@ -1601,9 +1642,9 @@ dependencies = [ [[package]] name = "http" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" dependencies = [ "bytes", "fnv", @@ -1617,7 +1658,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http 0.2.11", + "http 0.2.12", "pin-project-lite", ] @@ -1650,7 +1691,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 0.2.11", + "http 0.2.12", "http-body", "httparse", "httpdate", @@ -1670,7 +1711,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", - "http 0.2.11", + "http 0.2.12", "hyper", "log", "rustls", @@ -1681,9 +1722,9 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.59" +version = "0.1.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6a67363e2aa4443928ce15e57ebae94fd8949958fd1223c4cfc0cd473ad7539" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1714,9 +1755,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.2" +version = "2.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" +checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -1743,12 +1784,12 @@ checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" -version = "0.4.10" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" dependencies = [ "hermit-abi", - "rustix", + "libc", "windows-sys 0.52.0", ] @@ -1804,26 +1845,27 @@ dependencies = [ [[package]] name = "jobserver" -version = "0.1.27" +version = "0.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" dependencies = [ "libc", ] [[package]] name = "js-sys" -version = "0.3.67" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" dependencies = [ "wasm-bindgen", ] [[package]] -name = "jsonpath_lib" -version = "0.3.0" -source = "git+https://github.com/ritchie46/jsonpath?branch=improve_compiled#24eaf0b4416edff38a4d1b6b17bc4b9f3f047b4b" +name = "jsonpath_lib_polars_vendor" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4bd9354947622f7471ff713eacaabdb683ccb13bba4edccaab9860abf480b7d" dependencies = [ "log", "serde", @@ -1914,7 +1956,20 @@ checksum = "5ff4ae71b685bbad2f2f391fe74f6b7659a34871c08b210fdc039e43bee07d18" dependencies = [ "adler32", "crc32fast", - "libflate_lz77", + "libflate_lz77 1.2.0", +] + +[[package]] +name = "libflate" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf" +dependencies = [ + "adler32", + "core2", + "crc32fast", + "dary_heap", + "libflate_lz77 2.0.0", ] [[package]] @@ -1926,6 +1981,17 @@ dependencies = [ "rle-decode-fast", ] +[[package]] +name = "libflate_lz77" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524" +dependencies = [ + "core2", + "hashbrown 0.13.2", + "rle-decode-fast", +] + [[package]] name = "libgit2-sys" version = "0.16.2+1.7.2" @@ -1940,12 +2006,12 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-sys 0.48.0", + "windows-targets 0.52.4", ] [[package]] @@ -2004,9 +2070,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.20" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "lz4" @@ -2089,18 +2155,18 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] [[package]] name = "mio" -version = "0.8.10" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "wasi", @@ -2187,9 +2253,9 @@ dependencies = [ [[package]] name = "num-complex" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies = [ "num-traits", ] @@ -2202,19 +2268,18 @@ checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-integer" -version = "0.1.45" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "autocfg", "num-traits", ] [[package]] name = "num-iter" -version = "0.1.43" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" dependencies = [ "autocfg", "num-integer", @@ -2235,9 +2300,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -2279,9 +2344,9 @@ dependencies = [ [[package]] name = "object_store" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d139f545f64630e2e3688fd9f81c470888ab01edeb72d13b4e86c566f1130000" +checksum = "b8718f8b65fdf67a45108d1548347d4af7d71fb81ce727bbf9e3b2535e079db3" dependencies = [ "async-trait", "base64", @@ -2291,13 +2356,14 @@ dependencies = [ "humantime", "hyper", "itertools 0.12.1", + "md-5", "parking_lot", "percent-encoding", "quick-xml", "rand", "reqwest", "ring", - "rustls-pemfile 2.0.0", + "rustls-pemfile 2.1.1", "serde", "serde_json", "snafu", @@ -2452,9 +2518,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.29" +version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2900ede94e305130c13ddd391e0ab7cbaeb783945ae07a279c268cb05109c6cb" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "planus" @@ -2495,24 +2561,36 @@ dependencies = [ [[package]] name = "polars" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", + "apache-avro", + "avro-schema", + "either", + "ethnum", + "futures", "getrandom", + "polars-arrow", "polars-core", + "polars-error", "polars-io", "polars-lazy", "polars-ops", + "polars-parquet", "polars-plan", "polars-sql", "polars-time", + "polars-utils", + "proptest", "rand", + "tokio", + "tokio-util", "version_check", ] [[package]] name = "polars-arrow" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "arrow-array", @@ -2554,7 +2632,7 @@ dependencies = [ "regex-syntax 0.8.2", "ryu", "sample-arrow2", - "sample-std", + "sample-std 0.1.1", "sample-test", "serde", "simdutf8", @@ -2580,7 +2658,7 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.37.0" +version = "0.38.2" dependencies = [ "bytemuck", "either", @@ -2588,13 +2666,14 @@ dependencies = [ "polars-arrow", "polars-error", "polars-utils", + "rand", "strength_reduce", "version_check", ] [[package]] name = "polars-core" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "arrow-array", @@ -2629,7 +2708,7 @@ dependencies = [ [[package]] name = "polars-doc-examples" -version = "0.37.0" +version = "0.38.2" dependencies = [ "aws-config", "aws-sdk-s3", @@ -2642,7 +2721,7 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.37.0" +version = "0.38.2" dependencies = [ "avro-schema", "object_store", @@ -2654,7 +2733,7 @@ dependencies = [ [[package]] name = "polars-ffi" -version = "0.37.0" +version = "0.38.2" dependencies = [ "polars-arrow", "polars-core", @@ -2662,7 +2741,7 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "async-trait", @@ -2706,7 +2785,7 @@ dependencies = [ [[package]] name = "polars-json" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "chrono", @@ -2725,7 +2804,7 @@ dependencies = [ [[package]] name = "polars-lazy" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "bitflags 2.4.2", @@ -2751,7 +2830,7 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "aho-corasick", @@ -2764,7 +2843,7 @@ dependencies = [ "hashbrown 0.14.3", "hex", "indexmap", - "jsonpath_lib", + "jsonpath_lib_polars_vendor", "memchr", "num-traits", "polars-arrow", @@ -2786,7 +2865,7 @@ dependencies = [ [[package]] name = "polars-parquet" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "async-stream", @@ -2813,7 +2892,7 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.37.0" +version = "0.38.2" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -2838,7 +2917,7 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "bytemuck", @@ -2869,8 +2948,9 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.37.0" +version = "0.38.2" dependencies = [ + "bytemuck", "polars-arrow", "polars-error", "polars-utils", @@ -2878,7 +2958,7 @@ dependencies = [ [[package]] name = "polars-sql" -version = "0.37.0" +version = "0.38.2" dependencies = [ "hex", "polars-arrow", @@ -2894,7 +2974,7 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.37.0" +version = "0.38.2" dependencies = [ "atoi", "chrono", @@ -2913,7 +2993,7 @@ dependencies = [ [[package]] name = "polars-utils" -version = "0.37.0" +version = "0.38.2" dependencies = [ "ahash", "bytemuck", @@ -2922,12 +3002,20 @@ dependencies = [ "num-traits", "once_cell", "polars-error", + "rand", + "raw-cpuid", "rayon", "smartstring", "sysinfo", "version_check", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "powerfmt" version = "0.2.0" @@ -2990,7 +3078,7 @@ dependencies = [ [[package]] name = "py-polars" -version = "0.20.8" +version = "0.20.16-rc.1" dependencies = [ "ahash", "built", @@ -3021,9 +3109,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" +checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ "cfg-if", "indoc", @@ -3031,6 +3119,7 @@ dependencies = [ "libc", "memoffset", "parking_lot", + "portable-atomic", "pyo3-build-config", "pyo3-ffi", "pyo3-macros", @@ -3039,9 +3128,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" +checksum = "deaa5745de3f5231ce10517a1f5dd97d53e5a2fd77aa6b5842292085831d48d7" dependencies = [ "once_cell", "target-lexicon", @@ -3055,9 +3144,9 @@ checksum = "be6d574e0f8cab2cdd1eeeb640cbf845c974519fa9e9b62fa9c08ecece0ca5de" [[package]] name = "pyo3-ffi" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" +checksum = "62b42531d03e08d4ef1f6e85a2ed422eb678b8cd62b762e53891c05faf0d4afa" dependencies = [ "libc", "pyo3-build-config", @@ -3065,28 +3154,35 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" +checksum = "7305c720fa01b8055ec95e484a6eca7a83c841267f0dd5280f0c8b8551d2c158" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] name = "pyo3-macros-backend" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" +checksum = "7c7e9b68bb9c3149c5b0cade5d07f953d6d125eb4337723c4ccdb665f1f96185" dependencies = [ "heck", "proc-macro2", + "pyo3-build-config", "quote", - "syn 2.0.48", + "syn 2.0.52", ] +[[package]] +name = "quad-rand" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" + [[package]] name = "quick-xml" version = "0.31.0" @@ -3176,6 +3272,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "raw-cpuid" +version = "11.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d86a7c4638d42c44551f4791a20e687dbb4c3de1f33c43dd71e355cd429def1" +dependencies = [ + "bitflags 2.4.2", +] + [[package]] name = "rawpointer" version = "0.2.1" @@ -3184,9 +3289,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" dependencies = [ "either", "rayon-core", @@ -3228,7 +3333,7 @@ checksum = "5fddb4f8d99b0a2ebafc65a87a69a7b9875e4b1ae1f00db265d300ef7f28bccc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -3245,9 +3350,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -3284,7 +3389,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http 0.2.11", + "http 0.2.12", "http-body", "hyper", "hyper-rustls", @@ -3328,16 +3433,17 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.7" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", + "cfg-if", "getrandom", "libc", "spin", "untrusted", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3415,9 +3521,9 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.0.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4" +checksum = "f48172685e6ff52a556baa527774f61fcaa884f59daf3375c62a3f1cd2549dab" dependencies = [ "base64", "rustls-pki-types", @@ -3425,9 +3531,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e9d979b3ce68192e42760c7810125eb6cf2ea10efae545a156063e61f314e2a" +checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8" [[package]] name = "rustls-webpki" @@ -3447,9 +3553,9 @@ checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -3462,12 +3568,12 @@ dependencies = [ [[package]] name = "sample-arrow2" -version = "0.17.1" +version = "0.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "722050f91586506195398fd22d834eb8768716084f6ebf9f32b917ed422b6afb" +checksum = "502b30097ae5cc57ee8359bb59d8af349db022492de04596119d83f561ab8977" dependencies = [ "arrow2", - "sample-std", + "sample-std 0.2.1", ] [[package]] @@ -3483,6 +3589,19 @@ dependencies = [ "regex", ] +[[package]] +name = "sample-std" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "948bd219c6eb2b2ca1e004d8aefa8bbcf12614f60e0139b1758b49f9a94358c8" +dependencies = [ + "casey", + "quickcheck", + "rand", + "rand_regex", + "regex", +] + [[package]] name = "sample-test" version = "0.1.1" @@ -3490,7 +3609,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "713e500947ff19fc1ae2805afa33ef45f3bb2ec656c77d92252d24cf9e3091b2" dependencies = [ "quickcheck", - "sample-std", + "sample-std 0.1.1", "sample-test-macros", ] @@ -3502,7 +3621,7 @@ checksum = "df1a2c832a259aae95b6ed1da3aa377111ffde38d4282fa734faa3fff356534e" dependencies = [ "proc-macro2", "quote", - "sample-std", + "sample-std 0.1.1", "syn 1.0.109", ] @@ -3570,9 +3689,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" dependencies = [ "serde", ] @@ -3585,29 +3704,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "indexmap", "itoa", @@ -3764,12 +3883,12 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.5" +version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3840,7 +3959,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -3862,9 +3981,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" dependencies = [ "proc-macro2", "quote", @@ -3879,9 +3998,9 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sysinfo" -version = "0.30.5" +version = "0.30.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fb4f3438c8f6389c864e61221cbc97e9bca98b4daf39a5beb7bea660f528bb2" +checksum = "0c385888ef380a852a16209afc8cfad22795dd8873d69c9a14d2e2088f118d18" dependencies = [ "cfg-if", "core-foundation-sys", @@ -3920,48 +4039,47 @@ checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" [[package]] name = "target-lexicon" -version = "0.12.13" +version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "tempfile" -version = "3.9.0" +version = "3.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" dependencies = [ "cfg-if", "fastrand", - "redox_syscall", "rustix", "windows-sys 0.52.0", ] [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] name = "time" -version = "0.3.32" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", "num-conv", @@ -4023,9 +4141,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.1" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "bytes", @@ -4047,7 +4165,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -4134,7 +4252,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] @@ -4152,6 +4270,26 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typed-builder" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34085c17941e36627a879208083e25d357243812c30e7d7387c3b954f30ade16" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "typenum" version = "1.17.0" @@ -4178,9 +4316,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" dependencies = [ "tinyvec", ] @@ -4196,9 +4334,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "unicode-width" @@ -4242,6 +4380,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" dependencies = [ "getrandom", + "serde", ] [[package]] @@ -4276,9 +4415,9 @@ checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" [[package]] name = "walkdir" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" dependencies = [ "same-file", "winapi-util", @@ -4301,9 +4440,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" +checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -4311,24 +4450,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" +checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.40" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bde2032aeb86bdfaecc8b261eef3cba735cc426c1f3a3416d1e0791be95fc461" +checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" dependencies = [ "cfg-if", "js-sys", @@ -4338,9 +4477,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" +checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4348,22 +4487,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" +checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.90" +version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" +checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] name = "wasm-streams" @@ -4380,9 +4519,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.67" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" +checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" dependencies = [ "js-sys", "wasm-bindgen", @@ -4426,7 +4565,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] @@ -4435,7 +4574,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] @@ -4453,7 +4592,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.0", + "windows-targets 0.52.4", ] [[package]] @@ -4473,17 +4612,17 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", + "windows_aarch64_gnullvm 0.52.4", + "windows_aarch64_msvc 0.52.4", + "windows_i686_gnu 0.52.4", + "windows_i686_msvc 0.52.4", + "windows_x86_64_gnu 0.52.4", + "windows_x86_64_gnullvm 0.52.4", + "windows_x86_64_msvc 0.52.4", ] [[package]] @@ -4494,9 +4633,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" [[package]] name = "windows_aarch64_msvc" @@ -4506,9 +4645,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" [[package]] name = "windows_i686_gnu" @@ -4518,9 +4657,9 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" [[package]] name = "windows_i686_msvc" @@ -4530,9 +4669,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" [[package]] name = "windows_x86_64_gnu" @@ -4542,9 +4681,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" [[package]] name = "windows_x86_64_gnullvm" @@ -4554,9 +4693,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" [[package]] name = "windows_x86_64_msvc" @@ -4566,15 +4705,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.0" +version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" +checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "winnow" -version = "0.5.36" +version = "0.5.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "818ce546a11a9986bc24f93d0cdf38a8a1a400f1473ea8c82e59f6e0ffab9249" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" dependencies = [ "memchr", ] @@ -4597,9 +4736,9 @@ checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" [[package]] name = "xxhash-rust" -version = "0.8.8" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53be06678ed9e83edb1745eb72efc0bbcd7b5c3c35711a860906aed827a13d61" +checksum = "927da81e25be1e1a2901d59b81b37dd2efd1fc9c9345a55007f09bf5a2d3ee03" [[package]] name = "zerocopy" @@ -4618,7 +4757,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.52", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 80c173e21a6da..bf71d65c89d9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ default-members = [ # ] [workspace.package] -version = "0.37.0" +version = "0.38.2" authors = ["Ritchie Vink "] edition = "2021" homepage = "https://www.pola.rs/" @@ -60,7 +60,8 @@ percent-encoding = "2.3" pyo3 = "0.20" rand = "0.8" rand_distr = "0.4" -rayon = "1.8" +raw-cpuid = "11" +rayon = "1.9" regex = "1.9" reqwest = { version = "0.11", default-features = false } ryu = "1.0.13" @@ -83,22 +84,22 @@ xxhash-rust = { version = "0.8.6", features = ["xxh3"] } zstd = "0.13" uuid = { version = "1.7.0", features = ["v4"] } -polars = { version = "0.37.0", path = "crates/polars", default-features = false } -polars-compute = { version = "0.37.0", path = "crates/polars-compute", default-features = false } -polars-core = { version = "0.37.0", path = "crates/polars-core", default-features = false } -polars-error = { version = "0.37.0", path = "crates/polars-error", default-features = false } -polars-ffi = { version = "0.37.0", path = "crates/polars-ffi", default-features = false } -polars-io = { version = "0.37.0", path = "crates/polars-io", default-features = false } -polars-json = { version = "0.37.0", path = "crates/polars-json", default-features = false } -polars-lazy = { version = "0.37.0", path = "crates/polars-lazy", default-features = false } -polars-ops = { version = "0.37.0", path = "crates/polars-ops", default-features = false } -polars-parquet = { version = "0.37.0", path = "crates/polars-parquet", default-features = false } -polars-pipe = { version = "0.37.0", path = "crates/polars-pipe", default-features = false } -polars-plan = { version = "0.37.0", path = "crates/polars-plan", default-features = false } -polars-row = { version = "0.37.0", path = "crates/polars-row", default-features = false } -polars-sql = { version = "0.37.0", path = "crates/polars-sql", default-features = false } -polars-time = { version = "0.37.0", path = "crates/polars-time", default-features = false } -polars-utils = { version = "0.37.0", path = "crates/polars-utils", default-features = false } +polars = { version = "0.38.2", path = "crates/polars", default-features = false } +polars-compute = { version = "0.38.2", path = "crates/polars-compute", default-features = false } +polars-core = { version = "0.38.2", path = "crates/polars-core", default-features = false } +polars-error = { version = "0.38.2", path = "crates/polars-error", default-features = false } +polars-ffi = { version = "0.38.2", path = "crates/polars-ffi", default-features = false } +polars-io = { version = "0.38.2", path = "crates/polars-io", default-features = false } +polars-json = { version = "0.38.2", path = "crates/polars-json", default-features = false } +polars-lazy = { version = "0.38.2", path = "crates/polars-lazy", default-features = false } +polars-ops = { version = "0.38.2", path = "crates/polars-ops", default-features = false } +polars-parquet = { version = "0.38.2", path = "crates/polars-parquet", default-features = false } +polars-pipe = { version = "0.38.2", path = "crates/polars-pipe", default-features = false } +polars-plan = { version = "0.38.2", path = "crates/polars-plan", default-features = false } +polars-row = { version = "0.38.2", path = "crates/polars-row", default-features = false } +polars-sql = { version = "0.38.2", path = "crates/polars-sql", default-features = false } +polars-time = { version = "0.38.2", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.38.2", path = "crates/polars-utils", default-features = false } [workspace.dependencies.arrow-format] package = "polars-arrow-format" @@ -106,7 +107,7 @@ version = "0.1.0" [workspace.dependencies.arrow] package = "polars-arrow" -version = "0.37.0" +version = "0.38.2" path = "crates/polars-arrow" default-features = false features = [ diff --git a/Makefile b/Makefile index da9e0a06bf019..e68a944367518 100644 --- a/Makefile +++ b/Makefile @@ -19,64 +19,65 @@ FILTER_PIP_WARNINGS=| grep -v "don't match your environment"; test $${PIPESTATUS .PHONY: requirements requirements: .venv ## Install/refresh Python project requirements - $(VENV_BIN)/python -m pip install --upgrade pip - $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-dev.txt - $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-lint.txt - $(VENV_BIN)/pip install --upgrade -r py-polars/docs/requirements-docs.txt - $(VENV_BIN)/pip install --upgrade -r docs/requirements.txt + @unset CONDA_PREFIX \ + && $(VENV_BIN)/python -m pip install --upgrade uv \ + && $(VENV_BIN)/uv pip install --upgrade -r py-polars/requirements-dev.txt \ + && $(VENV_BIN)/uv pip install --upgrade -r py-polars/requirements-lint.txt \ + && $(VENV_BIN)/uv pip install --upgrade -r py-polars/docs/requirements-docs.txt \ + && $(VENV_BIN)/uv pip install --upgrade -r docs/requirements.txt .PHONY: build build: .venv ## Compile and install Python Polars for development - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml \ + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml \ $(FILTER_PIP_WARNINGS) .PHONY: build-debug-opt build-debug-opt: .venv ## Compile and install Python Polars with minimal optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --profile opt-dev \ + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile opt-dev \ $(FILTER_PIP_WARNINGS) .PHONY: build-debug-opt-subset build-debug-opt-subset: .venv ## Compile and install Python Polars with minimal optimizations turned on and no default features - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --no-default-features --profile opt-dev \ + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --no-default-features --profile opt-dev \ $(FILTER_PIP_WARNINGS) .PHONY: build-opt build-opt: .venv ## Compile and install Python Polars with nearly full optimization on and debug assertions turned off, but with debug symbols on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --profile debug-release \ + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release \ $(FILTER_PIP_WARNINGS) .PHONY: build-release build-release: .venv ## Compile and install a faster Python Polars binary with full optimizations - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --release \ + @unset CONDA_PREFIX \ + && $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release \ $(FILTER_PIP_WARNINGS) .PHONY: build-native build-native: .venv ## Same as build, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml -- -C target-cpu=native \ + @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ + $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml \ $(FILTER_PIP_WARNINGS) .PHONY: build-debug-opt-native build-debug-opt-native: .venv ## Same as build-debug-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --profile opt-dev -- -C target-cpu=native \ + @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ + $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile opt-dev \ $(FILTER_PIP_WARNINGS) .PHONY: build-opt-native build-opt-native: .venv ## Same as build-opt, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --profile debug-release -- -C target-cpu=native \ + @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ + $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --profile debug-release \ $(FILTER_PIP_WARNINGS) .PHONY: build-release-native build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on - @unset CONDA_PREFIX && source $(VENV_BIN)/activate \ - && maturin develop -m py-polars/Cargo.toml --release -- -C target-cpu=native \ + @unset CONDA_PREFIX && RUSTFLAGS='-C target-cpu=native' \ + $(VENV_BIN)/maturin develop -m py-polars/Cargo.toml --release \ $(FILTER_PIP_WARNINGS) diff --git a/README.md b/README.md index 2f7fbf20c38ef..bfacbfa774e0e 100644 --- a/README.md +++ b/README.md @@ -228,7 +228,7 @@ Requires Rust version `>=1.71`. ## Contributing -Want to contribute? Read our [contribution guideline](/CONTRIBUTING.md). +Want to contribute? Read our [contributing guide](https://docs.pola.rs/development/contributing/). ## Python: compile Polars from source diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000000000..65df338d67b49 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,19 @@ +coverage: + status: + project: off + patch: off +ignore: + - crates/polars-arrow/src/io/flight/*.rs + - crates/polars-arrow/src/io/ipc/append/*.rs + - crates/polars-arrow/src/io/ipc/read/array/union.rs + - crates/polars-arrow/src/io/ipc/read/array/map.rs + - crates/polars-arrow/src/io/ipc/read/array/binary.rs + - crates/polars-arrow/src/io/ipc/read/array/fixed_size_binary.rs + - crates/polars-arrow/src/io/ipc/read/array/null.rs + - crates/polars-arrow/src/io/ipc/write/serialize/fixed_size_binary.rs + - crates/polars-arrow/src/io/ipc/write/serialize/union.rs + - crates/polars-arrow/src/io/ipc/write/serialize/map.rs + - crates/polars-arrow/src/array/union/*.rs + - crates/polars-arrow/src/array/map/*.rs + - crates/polars-arrow/src/array/fixed_size_binary/*.rs + diff --git a/crates/Makefile b/crates/Makefile index 6e4ded353458b..e344ceba4c502 100644 --- a/crates/Makefile +++ b/crates/Makefile @@ -42,6 +42,7 @@ miri: ## Run miri .PHONY: test test: ## Run tests cargo test --all-features \ + -p polars-compute \ -p polars-core \ -p polars-io \ -p polars-lazy \ @@ -57,6 +58,7 @@ test: ## Run tests .PHONY: nextest nextest: ## Run tests with nextest cargo nextest run --all-features \ + -p polars-compute \ -p polars-core \ -p polars-io \ -p polars-lazy \ diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index f19509b7b96a1..5f7ce63ad0380 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -159,7 +159,9 @@ compute = [ simd = [] # polars-arrow -timezones = [] +timezones = [ + "chrono-tz", +] dtype-array = [] dtype-decimal = ["atoi", "itoap"] bigidx = [] diff --git a/crates/polars-arrow/src/array/binary/from.rs b/crates/polars-arrow/src/array/binary/from.rs index 73df03531594b..9ffac9827bb82 100644 --- a/crates/polars-arrow/src/array/binary/from.rs +++ b/crates/polars-arrow/src/array/binary/from.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{BinaryArray, MutableBinaryArray}; use crate::offset::Offset; diff --git a/crates/polars-arrow/src/array/binary/mod.rs b/crates/polars-arrow/src/array/binary/mod.rs index 34d00e129c934..a0ba77030a831 100644 --- a/crates/polars-arrow/src/array/binary/mod.rs +++ b/crates/polars-arrow/src/array/binary/mod.rs @@ -161,6 +161,7 @@ impl BinaryArray { } /// Returns the element at index `i` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -225,6 +226,7 @@ impl BinaryArray { /// Slices this [`BinaryArray`]. /// # Implementation /// This function is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -374,6 +376,7 @@ impl BinaryArray { } /// Creates a [`BinaryArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -398,6 +401,7 @@ impl BinaryArray { } /// Creates a [`BinaryArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/binary/mutable.rs b/crates/polars-arrow/src/array/binary/mutable.rs index 1b0259dad476a..53a8ed32bb6f6 100644 --- a/crates/polars-arrow/src/array/binary/mutable.rs +++ b/crates/polars-arrow/src/array/binary/mutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -236,6 +235,7 @@ impl> FromIterator> for MutableBinaryArray MutableBinaryArray { /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -262,6 +262,7 @@ impl MutableBinaryArray { } /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -283,6 +284,7 @@ impl MutableBinaryArray { } /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -349,6 +351,7 @@ impl MutableBinaryArray { /// Extends the [`MutableBinaryArray`] from an `iterator` of values of trusted length. /// This differs from `extend_trusted_len_unchecked` which accepts iterator of optional /// values. + /// /// # Safety /// The `iterator` must be [`TrustedLen`] #[inline] @@ -378,6 +381,7 @@ impl MutableBinaryArray { } /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + /// /// # Safety /// The `iterator` must be [`TrustedLen`] #[inline] diff --git a/crates/polars-arrow/src/array/binary/mutable_values.rs b/crates/polars-arrow/src/array/binary/mutable_values.rs index d6c8c969f0587..613cbb0aba9e8 100644 --- a/crates/polars-arrow/src/array/binary/mutable_values.rs +++ b/crates/polars-arrow/src/array/binary/mutable_values.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -163,6 +162,7 @@ impl MutableBinaryValuesArray { } /// Returns the value of the element at index `i`. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -266,6 +266,7 @@ impl MutableBinaryValuesArray { } /// Extends [`MutableBinaryValuesArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -289,6 +290,7 @@ impl MutableBinaryValuesArray { } /// Returns a new [`MutableBinaryValuesArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index 89216f9a3b749..63e1e1beea817 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -26,6 +26,7 @@ mod private { } pub use iterator::BinaryViewValueIter; pub use mutable::MutableBinaryViewArray; +use polars_utils::slice::GetSaferUnchecked; use private::Sealed; use crate::array::binview::view::{validate_binary_view, validate_utf8_only, validate_utf8_view}; @@ -33,7 +34,7 @@ use crate::array::iterator::NonNullValuesIter; use crate::bitmap::utils::{BitmapIter, ZipValidity}; pub type BinaryViewArray = BinaryViewArrayGeneric<[u8]>; pub type Utf8ViewArray = BinaryViewArrayGeneric; -pub use view::View; +pub use view::{View, INLINE_VIEW_SIZE}; pub type MutablePlString = MutableBinaryViewArray; pub type MutablePlBinary = MutableBinaryViewArray<[u8]>; @@ -178,6 +179,7 @@ impl BinaryViewArrayGeneric { } /// Create a new BinaryViewArray but initialize a statistics compute. + /// /// # Safety /// The caller must ensure the invariants pub unsafe fn new_unchecked_unknown_md( @@ -267,11 +269,12 @@ impl BinaryViewArrayGeneric { } /// Returns the element at index `i` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> &T { - let v = *self.views.get_unchecked(i); + let v = *self.views.get_unchecked_release(i); let len = v.length; // view layout: @@ -288,10 +291,12 @@ impl BinaryViewArrayGeneric { let ptr = self.views.as_ptr() as *const u8; std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize) } else { - let (data_ptr, data_len) = *self.raw_buffers.get_unchecked(v.buffer_idx as usize); + let (data_ptr, data_len) = *self + .raw_buffers + .get_unchecked_release(v.buffer_idx as usize); let data = std::slice::from_raw_parts(data_ptr, data_len); let offset = v.offset as usize; - data.get_unchecked(offset..offset + len as usize) + data.get_unchecked_release(offset..offset + len as usize) }; T::from_bytes_unchecked(bytes) } @@ -422,7 +427,7 @@ impl BinaryViewArray { /// Validate the underlying bytes on UTF-8. pub fn validate_utf8(&self) -> PolarsResult<()> { // SAFETY: views are correct - unsafe { validate_utf8_only(&self.views, &self.buffers) } + unsafe { validate_utf8_only(&self.views, &self.buffers, &self.buffers) } } /// Convert [`BinaryViewArray`] to [`Utf8ViewArray`]. diff --git a/crates/polars-arrow/src/array/binview/mutable.rs b/crates/polars-arrow/src/array/binview/mutable.rs index 4d62ff592c87c..467a6a8785d31 100644 --- a/crates/polars-arrow/src/array/binview/mutable.rs +++ b/crates/polars-arrow/src/array/binview/mutable.rs @@ -8,7 +8,7 @@ use polars_utils::slice::GetSaferUnchecked; use crate::array::binview::iterator::MutableBinaryViewValueIter; use crate::array::binview::view::validate_utf8_only; use crate::array::binview::{BinaryViewArrayGeneric, ViewType}; -use crate::array::{Array, MutableArray, View}; +use crate::array::{Array, MutableArray, TryExtend, TryPush, View}; use crate::bitmap::MutableBitmap; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; @@ -99,6 +99,11 @@ impl MutableBinaryViewArray { &self.views } + #[inline] + pub fn completed_buffers(&self) -> &[Buffer] { + &self.completed_buffers + } + pub fn validity(&mut self) -> Option<&mut MutableBitmap> { self.validity.as_mut() } @@ -308,10 +313,13 @@ impl MutableBinaryViewArray { Self::from_iterator(slice.as_ref().iter().map(|opt_v| opt_v.as_ref())) } - fn finish_in_progress(&mut self) { + fn finish_in_progress(&mut self) -> bool { if !self.in_progress_buffer.is_empty() { self.completed_buffers .push(std::mem::take(&mut self.in_progress_buffer).into()); + true + } else { + false } } @@ -320,7 +328,14 @@ impl MutableBinaryViewArray { self.into() } + #[inline] + pub fn value(&self, i: usize) -> &T { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + /// Returns the element at index `i` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -363,10 +378,22 @@ impl MutableBinaryViewArray { } impl MutableBinaryViewArray<[u8]> { - pub fn validate_utf8(&mut self) -> PolarsResult<()> { - self.finish_in_progress(); + pub fn validate_utf8(&mut self, buffer_offset: usize, views_offset: usize) -> PolarsResult<()> { + // Finish the in progress as it might be required for validation. + let pushed = self.finish_in_progress(); // views are correct - unsafe { validate_utf8_only(&self.views, &self.completed_buffers) } + unsafe { + validate_utf8_only( + &self.views[views_offset..], + &self.completed_buffers[buffer_offset..], + &self.completed_buffers, + )? + } + // Restore in-progress buffer as we don't want to get too small buffers + if let (true, Some(last)) = (pushed, self.completed_buffers.pop()) { + self.in_progress_buffer = last.into_mut().right().unwrap(); + } + Ok(()) } } @@ -423,3 +450,21 @@ impl MutableArray for MutableBinaryViewArray { self.views.shrink_to_fit() } } + +impl> TryExtend> for MutableBinaryViewArray { + /// This is infallible and is implemented for consistency with all other types + #[inline] + fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { + self.extend(iter.into_iter()); + Ok(()) + } +} + +impl> TryPush> for MutableBinaryViewArray { + /// This is infallible and is implemented for consistency with all other types + #[inline(always)] + fn try_push(&mut self, item: Option

) -> PolarsResult<()> { + self.push(item.as_ref().map(|p| p.as_ref())); + Ok(()) + } +} diff --git a/crates/polars-arrow/src/array/binview/view.rs b/crates/polars-arrow/src/array/binview/view.rs index 34e7d799d3ea6..4c480afdc5c56 100644 --- a/crates/polars-arrow/src/array/binview/view.rs +++ b/crates/polars-arrow/src/array/binview/view.rs @@ -13,6 +13,8 @@ use crate::buffer::Buffer; use crate::datatypes::PrimitiveType; use crate::types::NativeType; +pub const INLINE_VIEW_SIZE: u32 = 12; + // We use this instead of u128 because we want alignment of <= 8 bytes. #[derive(Debug, Copy, Clone, Default)] #[repr(C)] @@ -148,8 +150,8 @@ where { for view in views { let len = view.length; - if len <= 12 { - if len < 12 && view.as_u128() >> (32 + len * 8) != 0 { + if len <= INLINE_VIEW_SIZE { + if len < INLINE_VIEW_SIZE && view.as_u128() >> (32 + len * 8) != 0 { polars_bail!(ComputeError: "view contained non-zero padding in prefix"); } @@ -193,25 +195,51 @@ pub(super) fn validate_utf8_view(views: &[View], buffers: &[Buffer]) -> Pola /// The views and buffers must uphold the invariants of BinaryView otherwise we will go OOB. pub(super) unsafe fn validate_utf8_only( views: &[View], - buffers: &[Buffer], + buffers_to_check: &[Buffer], + all_buffers: &[Buffer], ) -> PolarsResult<()> { - for view in views { - let len = view.length; - if len <= 12 { + // If we have no buffers, we don't have to branch. + if all_buffers.is_empty() { + for view in views { + let len = view.length; validate_utf8( view.to_le_bytes() .get_unchecked_release(4..4 + len as usize), )?; - } else { - let buffer_idx = view.buffer_idx; - let offset = view.offset; - let data = buffers.get_unchecked_release(buffer_idx as usize); - - let start = offset as usize; - let end = start + len as usize; - let b = &data.as_slice().get_unchecked_release(start..end); - validate_utf8(b)?; - }; + } + return Ok(()); + } + + // Fast path if all buffers are ascii + if buffers_to_check.iter().all(|buf| buf.is_ascii()) { + for view in views { + let len = view.length; + if len <= 12 { + validate_utf8( + view.to_le_bytes() + .get_unchecked_release(4..4 + len as usize), + )?; + } + } + } else { + for view in views { + let len = view.length; + if len <= 12 { + validate_utf8( + view.to_le_bytes() + .get_unchecked_release(4..4 + len as usize), + )?; + } else { + let buffer_idx = view.buffer_idx; + let offset = view.offset; + let data = all_buffers.get_unchecked_release(buffer_idx as usize); + + let start = offset as usize; + let end = start + len as usize; + let b = &data.as_slice().get_unchecked_release(start..end); + validate_utf8(b)?; + }; + } } Ok(()) diff --git a/crates/polars-arrow/src/array/boolean/from.rs b/crates/polars-arrow/src/array/boolean/from.rs index 81a5395ccc069..07553d78b7374 100644 --- a/crates/polars-arrow/src/array/boolean/from.rs +++ b/crates/polars-arrow/src/array/boolean/from.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{BooleanArray, MutableBooleanArray}; impl]>> From

for BooleanArray { diff --git a/crates/polars-arrow/src/array/boolean/mod.rs b/crates/polars-arrow/src/array/boolean/mod.rs index eb3401ee8d835..0737dcfece55d 100644 --- a/crates/polars-arrow/src/array/boolean/mod.rs +++ b/crates/polars-arrow/src/array/boolean/mod.rs @@ -136,6 +136,7 @@ impl BooleanArray { } /// Returns the element at index `i` as bool + /// /// # Safety /// Caller must be sure that `i < self.len()` #[inline] @@ -173,6 +174,7 @@ impl BooleanArray { /// Slices this [`BooleanArray`]. /// # Implementation /// This operation is `O(1)` as it amounts to increase two ref counts. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -279,6 +281,7 @@ impl BooleanArray { /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -298,6 +301,7 @@ impl BooleanArray { /// Creates a [`BooleanArray`] from an iterator of trusted length. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -321,6 +325,7 @@ impl BooleanArray { } /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/boolean/mutable.rs b/crates/polars-arrow/src/array/boolean/mutable.rs index 85f2c6177ba02..fd3b4a1989d79 100644 --- a/crates/polars-arrow/src/array/boolean/mutable.rs +++ b/crates/polars-arrow/src/array/boolean/mutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -142,6 +141,7 @@ impl MutableBooleanArray { /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. /// This differs from `extend_trusted_len_unchecked`, which accepts in iterator of optional values. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -172,6 +172,7 @@ impl MutableBooleanArray { } /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -255,6 +256,7 @@ impl MutableBooleanArray { /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -276,6 +278,7 @@ impl MutableBooleanArray { /// Creates a [`BooleanArray`] from an iterator of trusted length. /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len /// but this crate does not mark it as such. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -302,6 +305,7 @@ impl MutableBooleanArray { } /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/dictionary/mod.rs b/crates/polars-arrow/src/array/dictionary/mod.rs index 8794ac20aff7e..6947c9e071c7a 100644 --- a/crates/polars-arrow/src/array/dictionary/mod.rs +++ b/crates/polars-arrow/src/array/dictionary/mod.rs @@ -37,6 +37,7 @@ pub unsafe trait DictionaryKey: NativeType + TryInto + TryFrom + H const KEY_TYPE: IntegerType; /// Represents this key as a `usize`. + /// /// # Safety /// The caller _must_ have checked that the value can be casted to `usize`. #[inline] @@ -178,6 +179,7 @@ impl DictionaryArray { /// * the `data_type`'s logical type is not a `DictionaryArray` /// * the `data_type`'s keys is not compatible with `keys` /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// /// # Safety /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` pub unsafe fn try_new_unchecked( @@ -292,6 +294,7 @@ impl DictionaryArray { } /// Slices this [`DictionaryArray`]. + /// /// # Safety /// Safe iff `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { diff --git a/crates/polars-arrow/src/array/ffi.rs b/crates/polars-arrow/src/array/ffi.rs index e1dd62488b70e..9806eac25e972 100644 --- a/crates/polars-arrow/src/array/ffi.rs +++ b/crates/polars-arrow/src/array/ffi.rs @@ -26,6 +26,7 @@ pub(crate) unsafe trait ToFfi { /// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). pub(crate) trait FromFfi: Sized { /// Convert itself from FFI. + /// /// # Safety /// This function is intrinsically `unsafe` as it requires the FFI to be made according /// to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs index 70e421e746ad2..e439aac214aaa 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mod.rs @@ -105,6 +105,7 @@ impl FixedSizeBinaryArray { /// Slices this [`FixedSizeBinaryArray`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -151,6 +152,7 @@ impl FixedSizeBinaryArray { } /// Returns the element at index `i` as &str + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -264,20 +266,3 @@ impl FixedSizeBinaryArray { MutableFixedSizeBinaryArray::from(slice).into() } } - -pub trait FixedSizeBinaryValues { - fn values(&self) -> &[u8]; - fn size(&self) -> usize; -} - -impl FixedSizeBinaryValues for FixedSizeBinaryArray { - #[inline] - fn values(&self) -> &[u8] { - &self.values - } - - #[inline] - fn size(&self) -> usize { - self.size - } -} diff --git a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs index aba3904bef4eb..8f81ce86f6d85 100644 --- a/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs +++ b/crates/polars-arrow/src/array/fixed_size_binary/mutable.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; -use super::{FixedSizeBinaryArray, FixedSizeBinaryValues}; +use super::FixedSizeBinaryArray; use crate::array::physical_binary::extend_validity; use crate::array::{Array, MutableArray, TryExtendFromSelf}; use crate::bitmap::MutableBitmap; @@ -200,6 +200,7 @@ impl MutableFixedSizeBinaryArray { } /// Returns the element at index `i` as `&[u8]` + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -289,18 +290,6 @@ impl MutableArray for MutableFixedSizeBinaryArray { } } -impl FixedSizeBinaryValues for MutableFixedSizeBinaryArray { - #[inline] - fn values(&self) -> &[u8] { - &self.values - } - - #[inline] - fn size(&self) -> usize { - self.size - } -} - impl PartialEq for MutableFixedSizeBinaryArray { fn eq(&self, other: &Self) -> bool { self.iter().eq(other.iter()) diff --git a/crates/polars-arrow/src/array/fixed_size_list/mod.rs b/crates/polars-arrow/src/array/fixed_size_list/mod.rs index 612e134a5a5da..2cefb2e8ddaf0 100644 --- a/crates/polars-arrow/src/array/fixed_size_list/mod.rs +++ b/crates/polars-arrow/src/array/fixed_size_list/mod.rs @@ -111,6 +111,7 @@ impl FixedSizeListArray { /// Slices this [`FixedSizeListArray`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -156,6 +157,7 @@ impl FixedSizeListArray { } /// Returns the `Vec` at position `i`. + /// /// # Safety /// Caller must ensure that `i < self.len()` #[inline] diff --git a/crates/polars-arrow/src/array/growable/binview.rs b/crates/polars-arrow/src/array/growable/binview.rs index 200030f860e1e..affcb472cbec9 100644 --- a/crates/polars-arrow/src/array/growable/binview.rs +++ b/crates/polars-arrow/src/array/growable/binview.rs @@ -43,7 +43,6 @@ pub struct GrowableBinaryViewArray<'a, T: ViewType + ?Sized> { // See: #14201 buffers: PlIndexSet>, total_bytes_len: usize, - total_buffer_len: usize, } impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { @@ -73,10 +72,6 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { .map(|buf| BufferKey { inner: buf }) }) .collect::>(); - let total_buffer_len = arrays - .iter() - .map(|arr| arr.data_buffers().len()) - .sum::(); Self { arrays, @@ -85,27 +80,33 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { views: Vec::with_capacity(capacity), buffers, total_bytes_len: 0, - total_buffer_len, } } fn to(&mut self) -> BinaryViewArrayGeneric { let views = std::mem::take(&mut self.views); let buffers = std::mem::take(&mut self.buffers); + let mut total_buffer_len = 0; + let buffers = Arc::from( + buffers + .into_iter() + .map(|buf| { + let buf = buf.inner; + total_buffer_len += buf.len(); + buf.clone() + }) + .collect::>(), + ); let validity = self.validity.take(); + unsafe { BinaryViewArrayGeneric::::new_unchecked( self.data_type.clone(), views.into(), - Arc::from( - buffers - .into_iter() - .map(|buf| buf.inner.clone()) - .collect::>(), - ), + buffers, validity.map(|v| v.into()), self.total_bytes_len, - self.total_buffer_len, + total_buffer_len, ) .maybe_gc() } @@ -140,6 +141,7 @@ impl<'a, T: ViewType + ?Sized> GrowableBinaryViewArray<'a, T> { #[inline] /// Ignores the buffers and doesn't update the view. This is only correct in a filter. + /// /// # Safety /// doesn't check bounds pub unsafe fn extend_unchecked_no_buffers(&mut self, index: usize, start: usize, len: usize) { @@ -187,22 +189,7 @@ impl<'a, T: ViewType + ?Sized> Growable<'a> for GrowableBinaryViewArray<'a, T> { } impl<'a, T: ViewType + ?Sized> From> for BinaryViewArrayGeneric { - fn from(val: GrowableBinaryViewArray<'a, T>) -> Self { - unsafe { - BinaryViewArrayGeneric::::new_unchecked( - val.data_type, - val.views.into(), - Arc::from( - val.buffers - .into_iter() - .map(|buf| buf.inner.clone()) - .collect::>(), - ), - val.validity.map(|v| v.into()), - val.total_bytes_len, - val.total_buffer_len, - ) - .maybe_gc() - } + fn from(mut val: GrowableBinaryViewArray<'a, T>) -> Self { + val.to() } } diff --git a/crates/polars-arrow/src/array/growable/mod.rs b/crates/polars-arrow/src/array/growable/mod.rs index aea9cdd8789e6..ca8fc87a5a86d 100644 --- a/crates/polars-arrow/src/array/growable/mod.rs +++ b/crates/polars-arrow/src/array/growable/mod.rs @@ -1,8 +1,6 @@ //! Contains the trait [`Growable`] and corresponding concreate implementations, one per concrete array, //! that offer the ability to create a new [`Array`] out of slices of existing [`Array`]s. -use std::sync::Arc; - use crate::array::*; use crate::datatypes::*; @@ -37,11 +35,13 @@ mod utils; pub trait Growable<'a> { /// Extends this [`Growable`] with elements from the bounded [`Array`] at index `index` from /// a slice starting at `start` and length `len`. + /// /// # Safety /// Doesn't do any bound checks unsafe fn extend(&mut self, index: usize, start: usize, len: usize); /// Extends this [`Growable`] with null elements, disregarding the bound arrays + /// /// # Safety /// Doesn't do any bound checks fn extend_validity(&mut self, additional: usize); diff --git a/crates/polars-arrow/src/array/indexable.rs b/crates/polars-arrow/src/array/indexable.rs index d3f466722aa63..dbf6b75c4bcf8 100644 --- a/crates/polars-arrow/src/array/indexable.rs +++ b/crates/polars-arrow/src/array/indexable.rs @@ -1,8 +1,9 @@ use std::borrow::Borrow; use crate::array::{ - MutableArray, MutableBinaryArray, MutableBinaryValuesArray, MutableBooleanArray, - MutableFixedSizeBinaryArray, MutablePrimitiveArray, MutableUtf8Array, MutableUtf8ValuesArray, + MutableArray, MutableBinaryArray, MutableBinaryValuesArray, MutableBinaryViewArray, + MutableBooleanArray, MutableFixedSizeBinaryArray, MutablePrimitiveArray, MutableUtf8Array, + MutableUtf8ValuesArray, ViewType, }; use crate::offset::Offset; use crate::types::NativeType; @@ -22,6 +23,7 @@ pub trait Indexable { fn value_at(&self, index: usize) -> Self::Value<'_>; /// Returns the element at index `i`. + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] @@ -124,6 +126,26 @@ impl AsIndexed for &[u8] { } } +impl Indexable for MutableBinaryViewArray { + type Value<'a> = &'a T; + type Type = T; + + fn value_at(&self, index: usize) -> Self::Value<'_> { + self.value(index) + } + + unsafe fn value_unchecked_at(&self, index: usize) -> Self::Value<'_> { + self.value_unchecked(index) + } +} + +impl AsIndexed> for &T { + #[inline] + fn as_indexed(&self) -> &T { + self + } +} + // TODO: should NativeType derive from Hash? impl Indexable for MutablePrimitiveArray { type Value<'a> = T; diff --git a/crates/polars-arrow/src/array/list/mod.rs b/crates/polars-arrow/src/array/list/mod.rs index f820560c5a775..6c2934fa10611 100644 --- a/crates/polars-arrow/src/array/list/mod.rs +++ b/crates/polars-arrow/src/array/list/mod.rs @@ -114,6 +114,7 @@ impl ListArray { } /// Slices this [`ListArray`]. + /// /// # Safety /// The caller must ensure that `offset + length < self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -149,6 +150,7 @@ impl ListArray { } /// Returns the element at index `i` as &str + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] diff --git a/crates/polars-arrow/src/array/map/mod.rs b/crates/polars-arrow/src/array/map/mod.rs index d057192ef612e..c6ebfc353a06f 100644 --- a/crates/polars-arrow/src/array/map/mod.rs +++ b/crates/polars-arrow/src/array/map/mod.rs @@ -111,6 +111,7 @@ impl MapArray { } /// Returns a slice of this [`MapArray`]. + /// /// # Safety /// The caller must ensure that `offset + length < self.len()`. #[inline] @@ -168,6 +169,7 @@ impl MapArray { } /// Returns the element at index `i`. + /// /// # Safety /// Assumes that the `i < self.len`. #[inline] diff --git a/crates/polars-arrow/src/array/mod.rs b/crates/polars-arrow/src/array/mod.rs index 5dfe63e2a7470..93bb166edfa3d 100644 --- a/crates/polars-arrow/src/array/mod.rs +++ b/crates/polars-arrow/src/array/mod.rs @@ -75,6 +75,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { } /// Returns whether slot `i` is null. + /// /// # Safety /// The caller must ensure `i < self.len()` #[inline] @@ -103,6 +104,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// Slices the [`Array`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()` unsafe fn slice_unchecked(&mut self, offset: usize, length: usize); @@ -123,6 +125,7 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { /// # Implementation /// This operation is `O(1)` over `len`, as it amounts to increase two ref counts /// and moving the struct to the heap. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()` #[must_use] @@ -143,15 +146,6 @@ pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { dyn_clone::clone_trait_object!(Array); -/// A trait describing an array with a backing store that can be preallocated to -/// a given size. -pub(crate) trait Container { - /// Create this array with a given capacity. - fn with_capacity(capacity: usize) -> Self - where - Self: Sized; -} - /// A trait describing a mutable array; i.e. an array whose values can be changed. /// Mutable arrays cannot be cloned but can be mutated in place, /// thereby making them useful to perform numeric operations without allocations. @@ -496,6 +490,7 @@ macro_rules! impl_sliced { /// Returns this array sliced. /// # Implementation /// This function is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -705,7 +700,7 @@ mod values; pub use binary::{BinaryArray, BinaryValueIter, MutableBinaryArray, MutableBinaryValuesArray}; pub use binview::{ BinaryViewArray, BinaryViewArrayGeneric, MutableBinaryViewArray, MutablePlBinary, - MutablePlString, Utf8ViewArray, View, ViewType, + MutablePlString, Utf8ViewArray, View, ViewType, INLINE_VIEW_SIZE, }; pub use boolean::{BooleanArray, MutableBooleanArray}; pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray}; @@ -746,6 +741,7 @@ pub trait TryPush { /// A trait describing the ability of a struct to receive new items. pub trait PushUnchecked { /// Push a new element that holds the invariants of the struct. + /// /// # Safety /// The items must uphold the invariants of the struct /// Read the specific implementation of the trait to understand what these are. diff --git a/crates/polars-arrow/src/array/null.rs b/crates/polars-arrow/src/array/null.rs index 900fb14005bbc..753700abba16c 100644 --- a/crates/polars-arrow/src/array/null.rs +++ b/crates/polars-arrow/src/array/null.rs @@ -62,6 +62,7 @@ impl NullArray { } /// Returns a slice of the [`NullArray`]. + /// /// # Safety /// The caller must ensure that `offset + length < self.len()`. pub unsafe fn slice_unchecked(&mut self, _offset: usize, length: usize) { diff --git a/crates/polars-arrow/src/array/primitive/from_natural.rs b/crates/polars-arrow/src/array/primitive/from_natural.rs index 0530c748af7e0..a70259a8eeff8 100644 --- a/crates/polars-arrow/src/array/primitive/from_natural.rs +++ b/crates/polars-arrow/src/array/primitive/from_natural.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{MutablePrimitiveArray, PrimitiveArray}; use crate::types::NativeType; diff --git a/crates/polars-arrow/src/array/primitive/mod.rs b/crates/polars-arrow/src/array/primitive/mod.rs index 0dc6992918fb9..8f70d233baeb5 100644 --- a/crates/polars-arrow/src/array/primitive/mod.rs +++ b/crates/polars-arrow/src/array/primitive/mod.rs @@ -208,6 +208,7 @@ impl PrimitiveArray { /// Returns the value at index `i`. /// The value on null slots is undetermined (it can be anything). + /// /// # Safety /// Caller must be sure that `i < self.len()` #[inline] @@ -243,6 +244,7 @@ impl PrimitiveArray { /// Slices this [`PrimitiveArray`] by an offset and length. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -420,6 +422,7 @@ impl PrimitiveArray { } /// Creates a new [`PrimitiveArray`] from an iterator over values + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -433,6 +436,7 @@ impl PrimitiveArray { } /// Creates a [`PrimitiveArray`] from an iterator of optional values. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/primitive/mutable.rs b/crates/polars-arrow/src/array/primitive/mutable.rs index fcf8032af7d99..38989bf1b1479 100644 --- a/crates/polars-arrow/src/array/primitive/mutable.rs +++ b/crates/polars-arrow/src/array/primitive/mutable.rs @@ -1,8 +1,6 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::PolarsResult; -use polars_utils::total_ord::TotalOrdWrap; use super::{check, PrimitiveArray}; use crate::array::physical_binary::extend_validity; @@ -129,17 +127,20 @@ impl MutablePrimitiveArray { } } + #[inline] + pub fn push_value(&mut self, value: T) { + self.values.push(value); + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + } + /// Adds a new value to the array. #[inline] pub fn push(&mut self, value: Option) { match value { - Some(value) => { - self.values.push(value); - match &mut self.validity { - Some(validity) => validity.push(true), - None => {}, - } - }, + Some(value) => self.push_value(value), None => { self.values.push(T::default()); match &mut self.validity { @@ -195,6 +196,7 @@ impl MutablePrimitiveArray { } /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -224,6 +226,7 @@ impl MutablePrimitiveArray { /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. /// This differs from `extend_trusted_len_unchecked` which accepts in iterator of optional values. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -288,6 +291,24 @@ impl MutablePrimitiveArray { pub fn freeze(self) -> PrimitiveArray { self.into() } + + /// Clears the array, removing all values. + /// + /// Note that this method has no effect on the allocated capacity + /// of the array. + pub fn clear(&mut self) { + self.values.clear(); + self.validity = None; + } + + /// Apply a function that temporarily freezes this `MutableArray` into a `PrimitiveArray`. + pub fn with_freeze) -> K>(&mut self, f: F) -> K { + let mutable = std::mem::take(self); + let arr = mutable.freeze(); + let out = f(&arr); + *self = arr.into_mut().right().unwrap(); + out + } } /// Accessors @@ -320,6 +341,7 @@ impl MutablePrimitiveArray { /// Sets position `index` to `value`. /// Note that if it is the first time a null appears in this array, /// this initializes the validity bitmap (`O(N)`). + /// /// # Safety /// Caller must ensure `index < self.len()` pub unsafe fn set_unchecked(&mut self, index: usize, value: Option) { @@ -364,14 +386,6 @@ impl Extend> for MutablePrimitiveArray { } } -impl Extend>> for MutablePrimitiveArray { - fn extend>>>(&mut self, iter: I) { - let iter = iter.into_iter(); - self.reserve(iter.size_hint().0); - iter.for_each(|x| self.push(x.map(|x| x.0))) - } -} - impl TryExtend> for MutablePrimitiveArray { /// This is infallible and is implemented for consistency with all other types fn try_extend>>(&mut self, iter: I) -> PolarsResult<()> { @@ -448,6 +462,7 @@ impl MutablePrimitiveArray { } /// Creates a [`MutablePrimitiveArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. `size_hint().1` correctly reports its length. @@ -477,6 +492,7 @@ impl MutablePrimitiveArray { } /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -525,6 +541,7 @@ impl MutablePrimitiveArray { } /// Creates a new [`MutablePrimitiveArray`] from an iterator over values + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/static_array_collect.rs b/crates/polars-arrow/src/array/static_array_collect.rs index 2da262cce3a05..9413b0a167789 100644 --- a/crates/polars-arrow/src/array/static_array_collect.rs +++ b/crates/polars-arrow/src/array/static_array_collect.rs @@ -807,12 +807,14 @@ impl ArrayFromIter> for BooleanArray { // as Rust considers that AsRef for Option<&dyn Array> could be implemented. trait AsArray { fn as_array(&self) -> &dyn Array; + #[cfg(feature = "dtype-array")] fn into_boxed_array(self) -> Box; // Prevents unnecessary re-boxing. } impl AsArray for Box { fn as_array(&self) -> &dyn Array { self.as_ref() } + #[cfg(feature = "dtype-array")] fn into_boxed_array(self) -> Box { self } @@ -821,6 +823,7 @@ impl<'a> AsArray for &'a dyn Array { fn as_array(&self) -> &'a dyn Array { *self } + #[cfg(feature = "dtype-array")] fn into_boxed_array(self) -> Box { self.to_boxed() } diff --git a/crates/polars-arrow/src/array/struct_/mod.rs b/crates/polars-arrow/src/array/struct_/mod.rs index 6f796ac18ac42..21d3247bbc85a 100644 --- a/crates/polars-arrow/src/array/struct_/mod.rs +++ b/crates/polars-arrow/src/array/struct_/mod.rs @@ -178,6 +178,7 @@ impl StructArray { /// Slices this [`StructArray`]. /// # Implementation /// This operation is `O(F)` where `F` is the number of fields. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { diff --git a/crates/polars-arrow/src/array/union/mod.rs b/crates/polars-arrow/src/array/union/mod.rs index 6c4e153c2dfa3..86d7cbed7397b 100644 --- a/crates/polars-arrow/src/array/union/mod.rs +++ b/crates/polars-arrow/src/array/union/mod.rs @@ -240,6 +240,7 @@ impl UnionArray { /// Returns a slice of this [`UnionArray`]. /// # Implementation /// This operation is `O(F)` where `F` is the number of fields. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. #[inline] @@ -296,6 +297,7 @@ impl UnionArray { /// Returns the index and slot of the field to select from `self.fields`. /// The first value is guaranteed to be `< self.fields().len()` + /// /// # Safety /// This function is safe iff `index < self.len`. #[inline] @@ -323,6 +325,7 @@ impl UnionArray { } /// Returns the slot `index` as a [`Scalar`]. + /// /// # Safety /// This function is safe iff `i < self.len`. pub unsafe fn value_unchecked(&self, index: usize) -> Box { diff --git a/crates/polars-arrow/src/array/utf8/from.rs b/crates/polars-arrow/src/array/utf8/from.rs index c1dcaf09b10d2..6f90bac994959 100644 --- a/crates/polars-arrow/src/array/utf8/from.rs +++ b/crates/polars-arrow/src/array/utf8/from.rs @@ -1,5 +1,3 @@ -use std::iter::FromIterator; - use super::{MutableUtf8Array, Utf8Array}; use crate::offset::Offset; diff --git a/crates/polars-arrow/src/array/utf8/mod.rs b/crates/polars-arrow/src/array/utf8/mod.rs index f58534b0ae106..218e71323abf7 100644 --- a/crates/polars-arrow/src/array/utf8/mod.rs +++ b/crates/polars-arrow/src/array/utf8/mod.rs @@ -156,6 +156,7 @@ impl Utf8Array { } /// Returns the value of the element at index `i`, ignoring the array's validity. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -223,6 +224,7 @@ impl Utf8Array { /// Slices this [`Utf8Array`]. /// # Implementation /// This function is `O(1)` + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()`. pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { @@ -362,6 +364,7 @@ impl Utf8Array { /// * The last offset is not equal to the values' length. /// * the validity's length is not equal to `offsets.len()`. /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// /// # Safety /// This function is unsound iff: /// * The `values` between two consecutive `offsets` are not valid utf8 @@ -430,6 +433,7 @@ impl Utf8Array { } /// Creates a [`Utf8Array`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -453,6 +457,7 @@ impl Utf8Array { } /// Creates a [`Utf8Array`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/utf8/mutable.rs b/crates/polars-arrow/src/array/utf8/mutable.rs index a67d2e00b4c2f..ef9a5e8527b71 100644 --- a/crates/polars-arrow/src/array/utf8/mutable.rs +++ b/crates/polars-arrow/src/array/utf8/mutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -75,6 +74,7 @@ impl MutableUtf8Array { } /// Create a [`MutableUtf8Array`] out of low-end APIs. + /// /// # Safety /// The caller must ensure that every value between offsets is a valid utf8. /// # Panics @@ -145,14 +145,13 @@ impl MutableUtf8Array { } /// Returns the value of the element at index `i`, ignoring the array's validity. - /// # Safety - /// This function is safe iff `i < self.len`. #[inline] pub fn value(&self, i: usize) -> &str { self.values.value(i) } /// Returns the value of the element at index `i`, ignoring the array's validity. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -329,6 +328,7 @@ impl MutableUtf8Array { /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. /// This differs from `extended_trusted_len_unchecked` which accepts iterator of optional /// values. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -357,6 +357,7 @@ impl MutableUtf8Array { } /// Extends [`MutableUtf8Array`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -376,6 +377,7 @@ impl MutableUtf8Array { } /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -404,6 +406,7 @@ impl MutableUtf8Array { } /// Creates a [`MutableUtf8Array`] from an iterator of trusted length of `&str`. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. @@ -440,6 +443,7 @@ impl MutableUtf8Array { } /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/array/utf8/mutable_values.rs b/crates/polars-arrow/src/array/utf8/mutable_values.rs index e76e62c88243f..ce3c2f71f20cd 100644 --- a/crates/polars-arrow/src/array/utf8/mutable_values.rs +++ b/crates/polars-arrow/src/array/utf8/mutable_values.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -95,6 +94,7 @@ impl MutableUtf8ValuesArray { /// This function does not panic iff: /// * The last offset is equal to the values' length. /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. + /// /// # Safety /// This function is safe iff: /// * the offsets are monotonically increasing @@ -201,6 +201,7 @@ impl MutableUtf8ValuesArray { } /// Returns the value of the element at index `i`. + /// /// # Safety /// This function is safe iff `i < self.len`. #[inline] @@ -309,6 +310,7 @@ impl MutableUtf8ValuesArray { } /// Extends [`MutableUtf8ValuesArray`] from an iterator of trusted len. + /// /// # Safety /// The iterator must be trusted len. #[inline] @@ -333,6 +335,7 @@ impl MutableUtf8ValuesArray { } /// Returns a new [`MutableUtf8ValuesArray`] from an iterator of trusted length. + /// /// # Safety /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). /// I.e. that `size_hint().1` correctly reports its length. diff --git a/crates/polars-arrow/src/bitmap/bitmask.rs b/crates/polars-arrow/src/bitmap/bitmask.rs index 8594ecf30fb81..67785a49eedab 100644 --- a/crates/polars-arrow/src/bitmap/bitmask.rs +++ b/crates/polars-arrow/src/bitmap/bitmask.rs @@ -1,6 +1,8 @@ #[cfg(feature = "simd")] use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount}; +use polars_utils::slice::load_padded_le_u64; + use crate::bitmap::Bitmap; /// Returns the nth set bit in w, if n+1 bits are set. The indexing is @@ -68,29 +70,6 @@ fn nth_set_bit_u32(w: u32, n: u32) -> Option { } } -// Loads a u64 from the given byteslice, as if it were padded with zeros. -fn load_padded_le_u64(bytes: &[u8]) -> u64 { - let len = bytes.len(); - if len >= 8 { - return u64::from_le_bytes(bytes[0..8].try_into().unwrap()); - } - - if len >= 4 { - let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); - let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap()); - return (lo as u64) | ((hi as u64) << (8 * (len - 4))); - } - - if len == 0 { - return 0; - } - - let lo = bytes[0] as u64; - let mid = (bytes[len / 2] as u64) << (8 * (len / 2)); - let hi = (bytes[len - 1] as u64) << (8 * (len - 1)); - lo | mid | hi -} - #[derive(Default, Clone)] pub struct BitMask<'a> { bytes: &'a [u8], diff --git a/crates/polars-arrow/src/bitmap/immutable.rs b/crates/polars-arrow/src/bitmap/immutable.rs index 1f1c3675031c3..8e156939c3cd4 100644 --- a/crates/polars-arrow/src/bitmap/immutable.rs +++ b/crates/polars-arrow/src/bitmap/immutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -8,7 +7,11 @@ use polars_error::{polars_bail, PolarsResult}; use super::utils::{count_zeros, fmt, get_bit, get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; use super::{chunk_iter_to_vec, IntoIter, MutableBitmap}; +use crate::bitmap::iterator::{ + FastU32BitmapIter, FastU56BitmapIter, FastU64BitmapIter, TrueIdxIter, +}; use crate::buffer::Bytes; +use crate::legacy::utils::FromTrustedLenIterator; use crate::trusted_len::TrustedLen; const UNKNOWN_BIT_COUNT: u64 = u64::MAX; @@ -141,6 +144,29 @@ impl Bitmap { BitChunks::new(&self.bytes, self.offset, self.length) } + /// Returns a fast iterator that gives 32 bits at a time. + /// Has a remainder that must be handled separately. + pub fn fast_iter_u32(&self) -> FastU32BitmapIter<'_> { + FastU32BitmapIter::new(&self.bytes, self.offset, self.length) + } + + /// Returns a fast iterator that gives 56 bits at a time. + /// Has a remainder that must be handled separately. + pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> { + FastU56BitmapIter::new(&self.bytes, self.offset, self.length) + } + + /// Returns a fast iterator that gives 64 bits at a time. + /// Has a remainder that must be handled separately. + pub fn fast_iter_u64(&self) -> FastU64BitmapIter<'_> { + FastU64BitmapIter::new(&self.bytes, self.offset, self.length) + } + + /// Returns an iterator that only iterates over the set bits. + pub fn true_idx_iter(&self) -> TrueIdxIter<'_> { + TrueIdxIter::new(self.len(), Some(self)) + } + /// Returns the byte slice of this [`Bitmap`]. /// /// The returned tuple contains: @@ -159,6 +185,22 @@ impl Bitmap { ) } + /// Returns the number of set bits on this [`Bitmap`]. + /// + /// See `unset_bits` for details. + #[inline] + pub fn set_bits(&self) -> usize { + self.length - self.unset_bits() + } + + /// Returns the number of set bits on this [`Bitmap`] if it is known. + /// + /// See `lazy_unset_bits` for details. + #[inline] + pub fn lazy_set_bits(&self) -> Option { + Some(self.length - self.lazy_unset_bits()?) + } + /// Returns the number of unset bits on this [`Bitmap`]. /// /// Guaranteed to be `<= self.len()`. @@ -179,6 +221,30 @@ impl Bitmap { } } + /// Returns the number of unset bits on this [`Bitmap`] if it is known. + /// + /// Guaranteed to be `<= self.len()`. + pub fn lazy_unset_bits(&self) -> Option { + let cache = self.unset_bit_count_cache.load(Ordering::Relaxed); + if cache >> 63 != 0 { + None + } else { + Some(cache as usize) + } + } + + /// Updates the count of the number of set bits on this [`Bitmap`]. + /// + /// # Safety + /// + /// The number of set bits must be correct. + pub unsafe fn update_bit_count(&mut self, bits_set: usize) { + assert!(bits_set <= self.length); + let zeros = self.length - bits_set; + self.unset_bit_count_cache + .store(zeros as u64, Ordering::Relaxed); + } + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. /// # Panic /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` @@ -190,6 +256,7 @@ impl Bitmap { } /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// /// # Safety /// The caller must ensure that `self.offset + offset + length <= self.len()` #[inline] @@ -246,6 +313,7 @@ impl Bitmap { } /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// /// # Safety /// The caller must ensure that `self.offset + offset + length <= self.len()` #[inline] @@ -264,6 +332,7 @@ impl Bitmap { } /// Unsafely returns whether the bit at position `i` is set. + /// /// # Safety /// Unsound iff `i >= self.len()`. #[inline] @@ -358,7 +427,7 @@ impl Bitmap { /// Alias for `Bitmap::try_new().unwrap()` /// This function is `O(1)` /// # Panic - /// This function panics iff `length <= bytes.len() * 8` + /// This function panics iff `length > bytes.len() * 8` #[inline] pub fn from_u8_vec(vec: Vec, length: usize) -> Self { Bitmap::try_new(vec, length).unwrap() @@ -416,8 +485,18 @@ impl FromIterator for Bitmap { } } +impl FromTrustedLenIterator for Bitmap { + fn from_iter_trusted_length>(iter: T) -> Self + where + T::IntoIter: TrustedLen, + { + MutableBitmap::from_trusted_len_iter(iter.into_iter()).into() + } +} + impl Bitmap { /// Creates a new [`Bitmap`] from an iterator of booleans. + /// /// # Safety /// The iterator must report an accurate length. #[inline] @@ -440,6 +519,7 @@ impl Bitmap { } /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + /// /// # Safety /// The iterator must report an accurate length. #[inline] diff --git a/crates/polars-arrow/src/bitmap/iterator.rs b/crates/polars-arrow/src/bitmap/iterator.rs index 2bb812adb68f7..63c61afd83f72 100644 --- a/crates/polars-arrow/src/bitmap/iterator.rs +++ b/crates/polars-arrow/src/bitmap/iterator.rs @@ -1,7 +1,30 @@ +use polars_utils::slice::load_padded_le_u64; + use super::bitmask::BitMask; use super::Bitmap; use crate::trusted_len::TrustedLen; +/// Calculates how many iterations are remaining, assuming: +/// - We have length elements left. +/// - We need max(consume, min_length_for_iter) elements to start a new iteration. +/// - On each iteration we consume the given amount of elements. +fn calc_iters_remaining(length: usize, min_length_for_iter: usize, consume: usize) -> usize { + let min_length_for_iter = min_length_for_iter.max(consume); + if length < min_length_for_iter { + return 0; + } + + let obvious_part = length - min_length_for_iter; + let obvious_iters = obvious_part / consume; + // let obvious_part_remaining = obvious_part % consume; + // let total_remaining = min_length_for_iter + obvious_part_remaining; + // assert!(total_remaining >= min_length_for_iter); // We have at least 1 more iter. + // assert!(obvious_part_remaining < consume); // Basic modulo property. + // assert!(total_remaining < min_length_for_iter + consume); // Add min_length_for_iter to both sides. + // assert!(total_remaining - consume < min_length_for_iter); // Not enough remaining after 1 iter. + 1 + obvious_iters // Thus always exactly 1 more iter. +} + pub struct TrueIdxIter<'a> { mask: BitMask<'a>, first_unknown: usize, @@ -71,6 +94,262 @@ impl<'a> Iterator for TrueIdxIter<'a> { unsafe impl<'a> TrustedLen for TrueIdxIter<'a> {} +pub struct FastU32BitmapIter<'a> { + bytes: &'a [u8], + shift: u32, + bits_left: usize, +} + +impl<'a> FastU32BitmapIter<'a> { + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + assert!(bytes.len() * 8 >= offset + len); + let shift = (offset % 8) as u32; + let bytes = &bytes[offset / 8..]; + Self { + bytes, + shift, + bits_left: len, + } + } + + // The iteration logic that would normally follow the fast-path. + fn next_remainder(&mut self) -> Option { + if self.bits_left > 0 { + let word = load_padded_le_u64(self.bytes); + let mask; + if self.bits_left >= 32 { + mask = u32::MAX; + self.bits_left -= 32; + self.bytes = unsafe { self.bytes.get_unchecked(4..) }; + } else { + mask = (1 << self.bits_left) - 1; + self.bits_left = 0; + } + + return Some((word >> self.shift) as u32 & mask); + } + + None + } + + /// Returns the remainder bits and how many there are, + /// assuming the iterator was fully consumed. + pub fn remainder(mut self) -> (u64, usize) { + let bits_left = self.bits_left; + let lo = self.next_remainder().unwrap_or(0); + let hi = self.next_remainder().unwrap_or(0); + (((hi as u64) << 32) | (lo as u64), bits_left) + } +} + +impl<'a> Iterator for FastU32BitmapIter<'a> { + type Item = u32; + + #[inline] + fn next(&mut self) -> Option { + // Fast path, can load a whole u64. + if self.bits_left >= 64 { + let chunk; + unsafe { + // SAFETY: bits_left ensures this is in-bounds. + chunk = self.bytes.get_unchecked(0..8); + self.bytes = self.bytes.get_unchecked(4..); + } + self.bits_left -= 32; + let word = u64::from_le_bytes(chunk.try_into().unwrap()); + return Some((word >> self.shift) as u32); + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let hint = calc_iters_remaining(self.bits_left, 64, 32); + (hint, Some(hint)) + } +} + +unsafe impl<'a> TrustedLen for FastU32BitmapIter<'a> {} + +pub struct FastU56BitmapIter<'a> { + bytes: &'a [u8], + shift: u32, + bits_left: usize, +} + +impl<'a> FastU56BitmapIter<'a> { + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + assert!(bytes.len() * 8 >= offset + len); + let shift = (offset % 8) as u32; + let bytes = &bytes[offset / 8..]; + Self { + bytes, + shift, + bits_left: len, + } + } + + // The iteration logic that would normally follow the fast-path. + fn next_remainder(&mut self) -> Option { + if self.bits_left > 0 { + let word = load_padded_le_u64(self.bytes); + let mask; + if self.bits_left >= 56 { + mask = (1 << 56) - 1; + self.bits_left -= 56; + self.bytes = unsafe { self.bytes.get_unchecked(7..) }; + } else { + mask = (1 << self.bits_left) - 1; + self.bits_left = 0; + }; + + return Some((word >> self.shift) & mask); + } + + None + } + + /// Returns the remainder bits and how many there are, + /// assuming the iterator was fully consumed. Output is safe but + /// not specified if the iterator wasn't fully consumed. + pub fn remainder(mut self) -> (u64, usize) { + let bits_left = self.bits_left; + let lo = self.next_remainder().unwrap_or(0); + let hi = self.next_remainder().unwrap_or(0); + ((hi << 56) | lo, bits_left) + } +} + +impl<'a> Iterator for FastU56BitmapIter<'a> { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + // Fast path, can load a whole u64. + if self.bits_left >= 64 { + let chunk; + unsafe { + // SAFETY: bits_left ensures this is in-bounds. + chunk = self.bytes.get_unchecked(0..8); + self.bytes = self.bytes.get_unchecked(7..); + self.bits_left -= 56; + } + + let word = u64::from_le_bytes(chunk.try_into().unwrap()); + let mask = (1 << 56) - 1; + return Some((word >> self.shift) & mask); + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let hint = calc_iters_remaining(self.bits_left, 64, 56); + (hint, Some(hint)) + } +} + +unsafe impl<'a> TrustedLen for FastU56BitmapIter<'a> {} + +pub struct FastU64BitmapIter<'a> { + bytes: &'a [u8], + shift: u32, + bits_left: usize, + next_word: u64, +} + +impl<'a> FastU64BitmapIter<'a> { + pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self { + assert!(bytes.len() * 8 >= offset + len); + let shift = (offset % 8) as u32; + let bytes = &bytes[offset / 8..]; + let next_word = load_padded_le_u64(bytes); + let bytes = bytes.get(8..).unwrap_or(&[]); + Self { + bytes, + shift, + bits_left: len, + next_word, + } + } + + #[inline] + fn combine(&self, lo: u64, hi: u64) -> u64 { + // Compiles to 128-bit SHRD instruction on x86-64. + // Yes, the % 64 is important for the compiler to generate optimal code. + let wide = ((hi as u128) << 64) | lo as u128; + (wide >> (self.shift % 64)) as u64 + } + + // The iteration logic that would normally follow the fast-path. + fn next_remainder(&mut self) -> Option { + if self.bits_left > 0 { + let lo = self.next_word; + let hi = load_padded_le_u64(self.bytes); + let mask; + if self.bits_left >= 64 { + mask = u64::MAX; + self.bits_left -= 64; + self.bytes = self.bytes.get(8..).unwrap_or(&[]); + } else { + mask = (1 << self.bits_left) - 1; + self.bits_left = 0; + }; + self.next_word = hi; + + return Some(self.combine(lo, hi) & mask); + } + + None + } + + /// Returns the remainder bits and how many there are, + /// assuming the iterator was fully consumed. Output is safe but + /// not specified if the iterator wasn't fully consumed. + pub fn remainder(mut self) -> ([u64; 2], usize) { + let bits_left = self.bits_left; + let lo = self.next_remainder().unwrap_or(0); + let hi = self.next_remainder().unwrap_or(0); + ([lo, hi], bits_left) + } +} + +impl<'a> Iterator for FastU64BitmapIter<'a> { + type Item = u64; + + #[inline] + fn next(&mut self) -> Option { + // Fast path: can load two u64s in a row. + // (Note that we already loaded one in the form of self.next_word). + if self.bits_left >= 128 { + let chunk; + unsafe { + // SAFETY: bits_left ensures this is in-bounds. + chunk = self.bytes.get_unchecked(0..8); + self.bytes = self.bytes.get_unchecked(8..); + } + let lo = self.next_word; + let hi = u64::from_le_bytes(chunk.try_into().unwrap()); + self.next_word = hi; + self.bits_left -= 64; + + return Some(self.combine(lo, hi)); + } + + None + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let hint = calc_iters_remaining(self.bits_left, 128, 64); + (hint, Some(hint)) + } +} + +unsafe impl<'a> TrustedLen for FastU64BitmapIter<'a> {} + /// This crates' equivalent of [`std::vec::IntoIter`] for [`Bitmap`]. #[derive(Debug, Clone)] pub struct IntoIter { diff --git a/crates/polars-arrow/src/bitmap/mutable.rs b/crates/polars-arrow/src/bitmap/mutable.rs index a6229e7e17646..f1ef33a182246 100644 --- a/crates/polars-arrow/src/bitmap/mutable.rs +++ b/crates/polars-arrow/src/bitmap/mutable.rs @@ -1,5 +1,4 @@ use std::hint::unreachable_unchecked; -use std::iter::FromIterator; use std::sync::Arc; use polars_error::{polars_bail, PolarsResult}; @@ -77,7 +76,7 @@ impl MutableBitmap { /// # Errors /// This function errors iff `length > bytes.len() * 8` #[inline] - pub fn try_new(bytes: Vec, length: usize) -> PolarsResult { + pub fn try_new(mut bytes: Vec, length: usize) -> PolarsResult { if length > bytes.len().saturating_mul(8) { polars_bail!(InvalidOperation: "The length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", @@ -85,6 +84,10 @@ impl MutableBitmap { bytes.len().saturating_mul(8) ) } + + // Ensure invariant holds. + let min_byte_length_needed = length.div_ceil(8); + bytes.drain(min_byte_length_needed..); Ok(Self { length, buffer: bytes, @@ -221,6 +224,7 @@ impl MutableBitmap { } /// Pushes a new bit to the [`MutableBitmap`] + /// /// # Safety /// The caller must ensure that the [`MutableBitmap`] has sufficient capacity. #[inline] @@ -318,6 +322,7 @@ impl MutableBitmap { } /// Sets the position `index` to `value` + /// /// # Safety /// Caller must ensure that `index < self.len()` #[inline] @@ -529,6 +534,7 @@ impl MutableBitmap { } /// Extends `self` from an iterator of trusted len. + /// /// # Safety /// The caller must guarantee that the iterator has a trusted len. #[inline] @@ -577,6 +583,7 @@ impl MutableBitmap { } /// Creates a new [`MutableBitmap`] from an iterator of booleans. + /// /// # Safety /// The iterator must report an accurate length. #[inline] @@ -610,6 +617,7 @@ impl MutableBitmap { } /// Creates a new [`MutableBitmap`] from an falible iterator of booleans. + /// /// # Safety /// The caller must guarantee that the iterator is `TrustedLen`. pub unsafe fn try_from_trusted_len_iter_unchecked( @@ -697,6 +705,7 @@ impl MutableBitmap { /// # Implementation /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + /// /// # Safety /// Caller must ensure `offset + length <= slice.len() * 8` #[inline] diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs index 4ab9d300ba021..7bc12e22898e2 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::slice::ChunksExact; use super::{BitChunk, BitChunkIterExact}; diff --git a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs index 71f56a2842749..8a1668a37d1f0 100644 --- a/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - mod chunks_exact; mod merge; diff --git a/crates/polars-arrow/src/bitmap/utils/mod.rs b/crates/polars-arrow/src/bitmap/utils/mod.rs index 45badf65b6f1f..4ec5786f1c4f6 100644 --- a/crates/polars-arrow/src/bitmap/utils/mod.rs +++ b/crates/polars-arrow/src/bitmap/utils/mod.rs @@ -11,69 +11,63 @@ pub use chunk_iterator::{BitChunk, BitChunkIterExact, BitChunks, BitChunksExact} pub use chunks_exact_mut::BitChunksExactMut; pub use fmt::fmt; pub use iterator::BitmapIter; +use polars_utils::slice::GetSaferUnchecked; pub use slice_iterator::SlicesIterator; pub use zip_validity::{ZipValidity, ZipValidityIter}; -const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; -const UNSET_BIT_MASK: [u8; 8] = [ - 255 - 1, - 255 - 2, - 255 - 4, - 255 - 8, - 255 - 16, - 255 - 32, - 255 - 64, - 255 - 128, -]; - /// Returns whether bit at position `i` in `byte` is set or not #[inline] pub fn is_set(byte: u8, i: usize) -> bool { - (byte & BIT_MASK[i]) != 0 + debug_assert!(i < 8); + byte & (1 << i) != 0 } -/// Sets bit at position `i` in `byte` +/// Sets bit at position `i` in `byte`. #[inline] pub fn set(byte: u8, i: usize, value: bool) -> u8 { - if value { - byte | BIT_MASK[i] - } else { - byte & UNSET_BIT_MASK[i] - } + debug_assert!(i < 8); + + let mask = !(1 << i); + let insert = (value as u8) << i; + (byte & mask) | insert } -/// Sets bit at position `i` in `data` +/// Sets bit at position `i` in `bytes`. /// # Panics -/// panics if `i >= data.len() / 8` +/// This function panics iff `i >= bytes.len() * 8`. #[inline] -pub fn set_bit(data: &mut [u8], i: usize, value: bool) { - data[i / 8] = set(data[i / 8], i % 8, value); +pub fn set_bit(bytes: &mut [u8], i: usize, value: bool) { + bytes[i / 8] = set(bytes[i / 8], i % 8, value); } -/// Sets bit at position `i` in `data` without doing bound checks +/// Sets bit at position `i` in `bytes` without doing bound checks /// # Safety -/// caller must ensure that `i < data.len() / 8` +/// `i >= bytes.len() * 8` results in undefined behavior. #[inline] -pub unsafe fn set_bit_unchecked(data: &mut [u8], i: usize, value: bool) { - let byte = data.get_unchecked_mut(i / 8); +pub unsafe fn set_bit_unchecked(bytes: &mut [u8], i: usize, value: bool) { + let byte = bytes.get_unchecked_mut(i / 8); *byte = set(*byte, i % 8, value); } -/// Returns whether bit at position `i` in `data` is set +/// Returns whether bit at position `i` in `bytes` is set. /// # Panic -/// This function panics iff `i / 8 >= bytes.len()` +/// This function panics iff `i >= bytes.len() * 8`. #[inline] pub fn get_bit(bytes: &[u8], i: usize) -> bool { - is_set(bytes[i / 8], i % 8) + let byte = bytes[i / 8]; + let bit = (byte >> (i % 8)) & 1; + bit != 0 } -/// Returns whether bit at position `i` in `data` is set or not. +/// Returns whether bit at position `i` in `bytes` is set or not. /// /// # Safety -/// `i >= data.len() * 8` results in undefined behavior +/// `i >= bytes.len() * 8` results in undefined behavior. #[inline] -pub unsafe fn get_bit_unchecked(data: &[u8], i: usize) -> bool { - (*data.as_ptr().add(i >> 3) & BIT_MASK[i & 7]) != 0 +pub unsafe fn get_bit_unchecked(bytes: &[u8], i: usize) -> bool { + let byte = *bytes.get_unchecked_release(i / 8); + let bit = (byte >> (i % 8)) & 1; + bit != 0 } /// Returns the number of bytes required to hold `bits` bits. diff --git a/crates/polars-arrow/src/buffer/immutable.rs b/crates/polars-arrow/src/buffer/immutable.rs index f6d49a0cb06b9..158a0ed4f61d2 100644 --- a/crates/polars-arrow/src/buffer/immutable.rs +++ b/crates/polars-arrow/src/buffer/immutable.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::ops::Deref; use std::sync::Arc; use std::usize; @@ -120,6 +119,7 @@ impl Buffer { } /// Returns the byte slice stored in this buffer + /// /// # Safety /// `index` must be smaller than `len` #[inline] @@ -159,6 +159,7 @@ impl Buffer { /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. /// Doing so allows the same memory region to be shared between buffers. + /// /// # Safety /// The caller must ensure `offset + length <= self.len()` #[inline] @@ -169,6 +170,7 @@ impl Buffer { } /// Slices this buffer starting at `offset`. + /// /// # Safety /// The caller must ensure `offset + length <= self.len()` #[inline] diff --git a/crates/polars-arrow/src/buffer/mod.rs b/crates/polars-arrow/src/buffer/mod.rs index b6c29d6cf47be..46ef0af62d984 100644 --- a/crates/polars-arrow/src/buffer/mod.rs +++ b/crates/polars-arrow/src/buffer/mod.rs @@ -29,6 +29,7 @@ impl Bytes { /// Takes ownership of an allocated memory region. /// # Panics /// This function panics if and only if pointer is not null + /// /// # Safety /// This function is safe if and only if `ptr` is valid for `length` /// # Implementation diff --git a/crates/polars-arrow/src/compute/aggregate/memory.rs b/crates/polars-arrow/src/compute/aggregate/memory.rs index d78ed4d23f501..8b59503b93e73 100644 --- a/crates/polars-arrow/src/compute/aggregate/memory.rs +++ b/crates/polars-arrow/src/compute/aggregate/memory.rs @@ -1,8 +1,8 @@ use crate::array::*; use crate::bitmap::Bitmap; use crate::datatypes::PhysicalType; +pub use crate::types::PrimitiveType; use crate::{match_integer_type, with_match_primitive_type_full}; - fn validity_size(validity: Option<&Bitmap>) -> usize { validity.as_ref().map(|b| b.as_slice().0.len()).unwrap_or(0) } @@ -24,9 +24,9 @@ macro_rules! dyn_binary { } fn binview_size(array: &BinaryViewArrayGeneric) -> usize { - array.views().len() * std::mem::size_of::() - + array.data_buffers().iter().map(|b| b.len()).sum::() - + validity_size(array.validity()) + // We choose the optimal usage as data can be shared across buffers. + // If we would sum all buffers we overestimate memory usage and trigger OOC when not needed. + array.total_bytes_len() } /// Returns the total (heap) allocated size of the array in bytes. @@ -48,6 +48,10 @@ pub fn estimated_bytes_size(array: &dyn Array) -> usize { let array = array.as_any().downcast_ref::().unwrap(); array.values().as_slice().0.len() + validity_size(array.validity()) }, + Primitive(PrimitiveType::DaysMs) => { + let array = array.as_any().downcast_ref::().unwrap(); + array.values().len() * std::mem::size_of::() * 2 + validity_size(array.validity()) + }, Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| { let array = array .as_any() diff --git a/crates/polars-arrow/src/compute/aggregate/simd/mod.rs b/crates/polars-arrow/src/compute/aggregate/simd/mod.rs index ea359fc592f5e..010ba336fc372 100644 --- a/crates/polars-arrow/src/compute/aggregate/simd/mod.rs +++ b/crates/polars-arrow/src/compute/aggregate/simd/mod.rs @@ -40,7 +40,8 @@ macro_rules! simd_add { }; } -pub(super) use simd_add; +// #[cfg(not(feature = "simd"))] +// pub(super) use simd_add; simd_add!(i128x8, i128, 8, add); diff --git a/crates/polars-arrow/src/compute/aggregate/simd/native.rs b/crates/polars-arrow/src/compute/aggregate/simd/native.rs index 01382ecbbc8f4..eb33878decbd6 100644 --- a/crates/polars-arrow/src/compute/aggregate/simd/native.rs +++ b/crates/polars-arrow/src/compute/aggregate/simd/native.rs @@ -1,7 +1,6 @@ use std::ops::Add; use super::super::sum::Sum; -use super::simd_add; use crate::types::simd::*; simd_add!(u8x64, u8, 64, wrapping_add); diff --git a/crates/polars-arrow/src/compute/cast/binview_to.rs b/crates/polars-arrow/src/compute/cast/binview_to.rs index f3c0a7de2b7c8..daab95aac2e1f 100644 --- a/crates/polars-arrow/src/compute/cast/binview_to.rs +++ b/crates/polars-arrow/src/compute/cast/binview_to.rs @@ -13,6 +13,28 @@ use crate::types::NativeType; pub(super) const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; +/// Cast [`BinaryViewArray`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub(super) fn binview_to_dictionary( + from: &BinaryViewArray, +) -> PolarsResult> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn utf8view_to_dictionary( + from: &Utf8ViewArray, +) -> PolarsResult> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + pub(super) fn view_to_binary(array: &BinaryViewArray) -> BinaryArray { let len: usize = Array::len(array); let mut mutable = MutableBinaryValuesArray::::with_capacities(len, array.total_bytes_len()); diff --git a/crates/polars-arrow/src/compute/cast/mod.rs b/crates/polars-arrow/src/compute/cast/mod.rs index 015eac0606eab..4b1c53c5de43e 100644 --- a/crates/polars-arrow/src/compute/cast/mod.rs +++ b/crates/polars-arrow/src/compute/cast/mod.rs @@ -17,15 +17,16 @@ pub use boolean_to::*; pub use decimal_to::*; pub use dictionary_to::*; use polars_error::{polars_bail, polars_ensure, polars_err, PolarsResult}; +use polars_utils::IdxSize; pub use primitive_to::*; pub use utf8_to::*; use crate::array::*; use crate::compute::cast::binview_to::{ - utf8view_to_date32_dyn, utf8view_to_naive_timestamp_dyn, view_to_binary, + binview_to_dictionary, utf8view_to_date32_dyn, utf8view_to_dictionary, + utf8view_to_naive_timestamp_dyn, view_to_binary, }; use crate::datatypes::*; -use crate::legacy::index::IdxSize; use crate::match_integer_type; use crate::offset::{Offset, Offsets}; use crate::temporal_conversions::utf8view_to_timestamp; @@ -293,6 +294,12 @@ pub fn cast( (Struct(_), _) | (_, Struct(_)) => polars_bail!(InvalidOperation: "Cannot cast from struct to other types" ), + (Dictionary(index_type, ..), _) => match_integer_type!(index_type, |$T| { + dictionary_cast_dyn::<$T>(array, to_type, options) + }), + (_, Dictionary(index_type, value_type, _)) => match_integer_type!(index_type, |$T| { + cast_to_dictionary::<$T>(array, value_type, options) + }), // not supported by polars // (List(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( // array.as_any().downcast_ref().unwrap(), @@ -320,11 +327,6 @@ pub fn cast( options, ) .map(|x| x.boxed()), - // not supported by polars - // (List(_), List(_)) => { - // cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) - // .map(|x| x.boxed()) - // }, (BinaryView, _) => match to_type { Utf8View => array .as_any() @@ -430,12 +432,6 @@ pub fn cast( } }, - (Dictionary(index_type, ..), _) => match_integer_type!(index_type, |$T| { - dictionary_cast_dyn::<$T>(array, to_type, options) - }), - (_, Dictionary(index_type, value_type, _)) => match_integer_type!(index_type, |$T| { - cast_to_dictionary::<$T>(array, value_type, options) - }), (_, Boolean) => match from_type { UInt8 => primitive_to_boolean_dyn::(array, to_type.clone()), UInt16 => primitive_to_boolean_dyn::(array, to_type.clone()), @@ -447,6 +443,7 @@ pub fn cast( Int64 => primitive_to_boolean_dyn::(array, to_type.clone()), Float32 => primitive_to_boolean_dyn::(array, to_type.clone()), Float64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Decimal(_, _) => primitive_to_boolean_dyn::(array, to_type.clone()), _ => polars_bail!(InvalidOperation: "casting from {from_type:?} to {to_type:?} not supported", ), @@ -774,6 +771,14 @@ fn cast_to_dictionary( ArrowDataType::UInt16 => primitive_to_dictionary_dyn::(array), ArrowDataType::UInt32 => primitive_to_dictionary_dyn::(array), ArrowDataType::UInt64 => primitive_to_dictionary_dyn::(array), + ArrowDataType::BinaryView => { + binview_to_dictionary::(array.as_any().downcast_ref().unwrap()) + .map(|arr| arr.boxed()) + }, + ArrowDataType::Utf8View => { + utf8view_to_dictionary::(array.as_any().downcast_ref().unwrap()) + .map(|arr| arr.boxed()) + }, ArrowDataType::LargeUtf8 => utf8_to_dictionary_dyn::(array), ArrowDataType::LargeBinary => binary_to_dictionary_dyn::(array), ArrowDataType::Time64(_) => primitive_to_dictionary_dyn::(array), diff --git a/crates/polars-arrow/src/compute/take/bitmap.rs b/crates/polars-arrow/src/compute/take/bitmap.rs index a57f575170a41..75bee16f0abbb 100644 --- a/crates/polars-arrow/src/compute/take/bitmap.rs +++ b/crates/polars-arrow/src/compute/take/bitmap.rs @@ -1,8 +1,11 @@ +use polars_utils::IdxSize; + +use crate::array::Array; use crate::bitmap::Bitmap; -use crate::legacy::index::IdxSize; +use crate::datatypes::IdxArr; /// # Safety -/// doesn't do any bound checks +/// Doesn't do any bound checks. pub unsafe fn take_bitmap_unchecked(values: &Bitmap, indices: &[IdxSize]) -> Bitmap { let values = indices.iter().map(|&index| { debug_assert!((index as usize) < values.len()); @@ -10,3 +13,26 @@ pub unsafe fn take_bitmap_unchecked(values: &Bitmap, indices: &[IdxSize]) -> Bit }); Bitmap::from_trusted_len_iter(values) } + +/// # Safety +/// Doesn't check bounds for non-null elements. +pub unsafe fn take_bitmap_nulls_unchecked(values: &Bitmap, indices: &IdxArr) -> Bitmap { + // Fast-path: no need to bother with null indices. + if indices.null_count() == 0 { + return take_bitmap_unchecked(values, indices.values()); + } + + if values.is_empty() { + // Nothing can be in-bounds, assume indices is full-null. + debug_assert!(indices.null_count() == indices.len()); + return Bitmap::new_zeroed(indices.len()); + } + + let values = indices.iter().map(|opt_index| { + // We checked that values.len() > 0 so we can use index 0 for nulls. + let index = opt_index.copied().unwrap_or(0) as usize; + debug_assert!(index < values.len()); + values.get_bit_unchecked(index) + }); + Bitmap::from_trusted_len_iter(values) +} diff --git a/crates/polars-arrow/src/compute/take/boolean.rs b/crates/polars-arrow/src/compute/take/boolean.rs index 049a3c4d5d9f1..3e6008d546527 100644 --- a/crates/polars-arrow/src/compute/take/boolean.rs +++ b/crates/polars-arrow/src/compute/take/boolean.rs @@ -1,14 +1,15 @@ -use super::bitmap::take_bitmap_unchecked; +use polars_utils::IdxSize; + +use super::bitmap::{take_bitmap_nulls_unchecked, take_bitmap_unchecked}; use crate::array::{Array, BooleanArray, PrimitiveArray}; use crate::bitmap::{Bitmap, MutableBitmap}; -use crate::legacy::index::IdxSize; -// take implementation when neither values nor indices contain nulls +// Take implementation when neither values nor indices contain nulls. unsafe fn take_no_validity(values: &Bitmap, indices: &[IdxSize]) -> (Bitmap, Option) { (take_bitmap_unchecked(values, indices), None) } -// take implementation when only values contain nulls +// Take implementation when only values contain nulls. unsafe fn take_values_validity( values: &BooleanArray, indices: &[IdxSize], @@ -22,18 +23,16 @@ unsafe fn take_values_validity( (buffer, validity.into()) } -// take implementation when only indices contain nulls +// Take implementation when only indices contain nulls. unsafe fn take_indices_validity( values: &Bitmap, indices: &PrimitiveArray, ) -> (Bitmap, Option) { - // simply take all and copy the bitmap - let buffer = take_bitmap_unchecked(values, indices.values()); - + let buffer = take_bitmap_nulls_unchecked(values, indices); (buffer, indices.validity().cloned()) } -// take implementation when both values and indices contain nulls +// Take implementation when both values and indices contain nulls. unsafe fn take_values_indices_validity( values: &BooleanArray, indices: &PrimitiveArray, diff --git a/crates/polars-arrow/src/compute/take/structure.rs b/crates/polars-arrow/src/compute/take/structure.rs index bd9be54dc4b0c..3619dae307bb6 100644 --- a/crates/polars-arrow/src/compute/take/structure.rs +++ b/crates/polars-arrow/src/compute/take/structure.rs @@ -28,7 +28,7 @@ pub(super) unsafe fn take_unchecked(array: &StructArray, indices: &IdxArr) -> St let validity = array .validity() - .map(|b| super::bitmap::take_bitmap_unchecked(b, indices.values())); + .map(|b| super::bitmap::take_bitmap_nulls_unchecked(b, indices)); let validity = combine_validities_and(validity.as_ref(), indices.validity()); StructArray::new(array.data_type().clone(), values, validity) } diff --git a/crates/polars-arrow/src/compute/utils.rs b/crates/polars-arrow/src/compute/utils.rs index edac9c8032d0b..744d12d2fe690 100644 --- a/crates/polars-arrow/src/compute/utils.rs +++ b/crates/polars-arrow/src/compute/utils.rs @@ -1,6 +1,6 @@ use std::ops::{BitAnd, BitOr}; -use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_error::{polars_ensure, PolarsResult}; use crate::array::Array; use crate::bitmap::{ternary, Bitmap}; diff --git a/crates/polars-arrow/src/ffi/array.rs b/crates/polars-arrow/src/ffi/array.rs index 4efeffee9b0ee..5de9176e39807 100644 --- a/crates/polars-arrow/src/ffi/array.rs +++ b/crates/polars-arrow/src/ffi/array.rs @@ -93,6 +93,7 @@ struct PrivateData { impl ArrowArray { /// creates a new `ArrowArray` from existing data. + /// /// # Safety /// This method releases `buffers`. Consumers of this struct *must* call `release` before /// releasing this struct, or contents in `buffers` leak. @@ -411,7 +412,8 @@ unsafe fn buffer_len( }) } -/// Safety +/// # Safety +/// /// This function is safe iff: /// * `array.children` at `index` is valid /// * `array.children` is not mutably shared for the lifetime of `parent` @@ -453,7 +455,8 @@ unsafe fn create_child( Ok(ArrowArrayChild::new(arr_ptr, data_type, parent)) } -/// Safety +/// # Safety +/// /// This function is safe iff: /// * `array.dictionary` is valid /// * `array.dictionary` is not mutably shared for the lifetime of `parent` @@ -488,6 +491,7 @@ pub trait ArrowArrayRef: std::fmt::Debug { /// returns the null bit buffer. /// Rust implementation uses a buffer that is not part of the array of buffers. /// The C Data interface's null buffer is part of the array of buffers. + /// /// # Safety /// The caller must guarantee that the buffer `index` corresponds to a bitmap. /// This function assumes that the bitmap created from FFI is valid; this is impossible to prove. diff --git a/crates/polars-arrow/src/ffi/schema.rs b/crates/polars-arrow/src/ffi/schema.rs index 09e09e0494b35..23cf9c8c4a479 100644 --- a/crates/polars-arrow/src/ffi/schema.rs +++ b/crates/polars-arrow/src/ffi/schema.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::convert::TryInto; use std::ffi::{CStr, CString}; use std::ptr; diff --git a/crates/polars-arrow/src/ffi/stream.rs b/crates/polars-arrow/src/ffi/stream.rs index 89855e1cd2019..58a0b07855290 100644 --- a/crates/polars-arrow/src/ffi/stream.rs +++ b/crates/polars-arrow/src/ffi/stream.rs @@ -54,6 +54,7 @@ impl> ArrowArrayStreamReader { /// # Error /// Errors iff the [`ArrowArrayStream`] is out of specification, /// or was already released prior to calling this function. + /// /// # Safety /// This method is intrinsically `unsafe` since it assumes that the `ArrowArrayStream` /// contains a valid Arrow C stream interface. @@ -101,6 +102,7 @@ impl> ArrowArrayStreamReader { /// Errors iff: /// * The C stream interface returns an error /// * The C stream interface returns an invalid array (that we can identify, see Safety below) + /// /// # Safety /// Calling this iterator's `next` assumes that the [`ArrowArrayStream`] produces arrow arrays /// that fulfill the C data interface diff --git a/crates/polars-arrow/src/io/avro/read/deserialize.rs b/crates/polars-arrow/src/io/avro/read/deserialize.rs index 3eaa556f56726..a2c8c83cc9b0a 100644 --- a/crates/polars-arrow/src/io/avro/read/deserialize.rs +++ b/crates/polars-arrow/src/io/avro/read/deserialize.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use avro_schema::file::Block; use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema}; use polars_error::{polars_bail, polars_err, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/read/array/binview.rs b/crates/polars-arrow/src/io/ipc/read/array/binview.rs index 40905c740e970..8d57250237911 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/binview.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/binview.rs @@ -1,14 +1,12 @@ -use std::collections::VecDeque; use std::io::{Read, Seek}; use std::sync::Arc; -use polars_error::{polars_err, PolarsResult}; +use polars_error::polars_err; use super::super::read_basic::*; use super::*; use crate::array::{ArrayRef, BinaryViewArrayGeneric, View, ViewType}; use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; #[allow(clippy::too_many_arguments)] pub fn read_binview( @@ -67,3 +65,34 @@ pub fn read_binview( BinaryViewArrayGeneric::::try_new(data_type, views, Arc::from(variadic_buffers), validity) .map(|arr| arr.boxed()) } + +pub fn skip_binview( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, +) -> PolarsResult<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + polars_err!( + oos = "IPC: unable to fetch the field for utf8. The file or stream is corrupted." + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing validity buffer."))?; + + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing views buffer."))?; + + let n_variadic = variadic_buffer_counts.pop_front().ok_or_else( + || polars_err!(ComputeError: "IPC: unable to fetch the variadic buffers\n\nThe file or stream is corrupted.") + )?; + + for _ in 0..n_variadic { + let _ = buffers + .pop_front() + .ok_or_else(|| polars_err!(oos = "IPC: missing variadic buffer"))?; + } + Ok(()) +} diff --git a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs index 846a2a8ea8fa5..5a43fe21e102b 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/dictionary.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek}; use ahash::HashSet; diff --git a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs index 335a426d0e44d..1f303a1567871 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -66,6 +66,7 @@ pub fn skip_fixed_size_list( field_nodes: &mut VecDeque, data_type: &ArrowDataType, buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { let _ = field_nodes.pop_front().ok_or_else(|| { polars_err!(oos = @@ -79,5 +80,10 @@ pub fn skip_fixed_size_list( let (field, _) = FixedSizeListArray::get_child_and_size(data_type); - skip(field_nodes, field.data_type(), buffers) + skip( + field_nodes, + field.data_type(), + buffers, + variadic_buffer_counts, + ) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/list.rs b/crates/polars-arrow/src/io/ipc/read/array/list.rs index c36646fe01929..45566fd5df9fc 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/list.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/list.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek}; use polars_error::{polars_err, PolarsResult}; @@ -86,6 +85,7 @@ pub fn skip_list( field_nodes: &mut VecDeque, data_type: &ArrowDataType, buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { let _ = field_nodes.pop_front().ok_or_else(|| { polars_err!( @@ -102,5 +102,5 @@ pub fn skip_list( let data_type = ListArray::::get_child_type(data_type); - skip(field_nodes, data_type, buffers) + skip(field_nodes, data_type, buffers, variadic_buffer_counts) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/map.rs b/crates/polars-arrow/src/io/ipc/read/array/map.rs index 2301085136b21..741d496a5a633 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/map.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/map.rs @@ -81,6 +81,7 @@ pub fn skip_map( field_nodes: &mut VecDeque, data_type: &ArrowDataType, buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { let _ = field_nodes.pop_front().ok_or_else(|| { polars_err!( @@ -97,5 +98,5 @@ pub fn skip_map( let data_type = MapArray::get_field(data_type).data_type(); - skip(field_nodes, data_type, buffers) + skip(field_nodes, data_type, buffers, variadic_buffer_counts) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs index 24b2a05ec6a4f..04304aadca901 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/primitive.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/primitive.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek}; use polars_error::{polars_err, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs index b90ba11a40287..6dc716ab368bc 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/struct_.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/struct_.rs @@ -71,6 +71,7 @@ pub fn skip_struct( field_nodes: &mut VecDeque, data_type: &ArrowDataType, buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { let _ = field_nodes.pop_front().ok_or_else(|| { polars_err!( @@ -84,7 +85,12 @@ pub fn skip_struct( let fields = StructArray::get_fields(data_type); - fields - .iter() - .try_for_each(|field| skip(field_nodes, field.data_type(), buffers)) + fields.iter().try_for_each(|field| { + skip( + field_nodes, + field.data_type(), + buffers, + variadic_buffer_counts, + ) + }) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/union.rs b/crates/polars-arrow/src/io/ipc/read/array/union.rs index 00409ef58e689..192d9582ed213 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/union.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/union.rs @@ -97,6 +97,7 @@ pub fn skip_union( field_nodes: &mut VecDeque, data_type: &ArrowDataType, buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { let _ = field_nodes.pop_front().ok_or_else(|| { polars_err!( @@ -117,7 +118,12 @@ pub fn skip_union( let fields = UnionArray::get_fields(data_type); - fields - .iter() - .try_for_each(|field| skip(field_nodes, field.data_type(), buffers)) + fields.iter().try_for_each(|field| { + skip( + field_nodes, + field.data_type(), + buffers, + variadic_buffer_counts, + ) + }) } diff --git a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs index 1408ff41435e1..f29f8d8cdb26f 100644 --- a/crates/polars-arrow/src/io/ipc/read/array/utf8.rs +++ b/crates/polars-arrow/src/io/ipc/read/array/utf8.rs @@ -1,13 +1,11 @@ -use std::collections::VecDeque; use std::io::{Read, Seek}; -use polars_error::{polars_err, PolarsResult}; +use polars_error::polars_err; use super::super::read_basic::*; use super::*; use crate::array::Utf8Array; use crate::buffer::Buffer; -use crate::datatypes::ArrowDataType; use crate::offset::Offset; #[allow(clippy::too_many_arguments)] diff --git a/crates/polars-arrow/src/io/ipc/read/common.rs b/crates/polars-arrow/src/io/ipc/read/common.rs index 87005dc76cc49..32aa535241734 100644 --- a/crates/polars-arrow/src/io/ipc/read/common.rs +++ b/crates/polars-arrow/src/io/ipc/read/common.rs @@ -2,7 +2,6 @@ use std::collections::VecDeque; use std::io::{Read, Seek}; use ahash::AHashMap; -use arrow_format; use polars_error::{polars_bail, polars_err, PolarsResult}; use super::deserialize::{read, skip}; @@ -150,7 +149,12 @@ pub fn read_record_batch( scratch, )?)), ProjectionResult::NotSelected((field, _)) => { - skip(&mut field_nodes, &field.data_type, &mut buffers)?; + skip( + &mut field_nodes, + &field.data_type, + &mut buffers, + &mut variadic_buffer_counts, + )?; Ok(None) }, }) diff --git a/crates/polars-arrow/src/io/ipc/read/deserialize.rs b/crates/polars-arrow/src/io/ipc/read/deserialize.rs index 972884c0af3fb..f27d9b58100e0 100644 --- a/crates/polars-arrow/src/io/ipc/read/deserialize.rs +++ b/crates/polars-arrow/src/io/ipc/read/deserialize.rs @@ -263,6 +263,7 @@ pub fn skip( field_nodes: &mut VecDeque, data_type: &ArrowDataType, buffers: &mut VecDeque, + variadic_buffer_counts: &mut VecDeque, ) -> PolarsResult<()> { use PhysicalType::*; match data_type.to_physical_type() { @@ -272,13 +273,15 @@ pub fn skip( LargeBinary | Binary => skip_binary(field_nodes, buffers), LargeUtf8 | Utf8 => skip_utf8(field_nodes, buffers), FixedSizeBinary => skip_fixed_size_binary(field_nodes, buffers), - List => skip_list::(field_nodes, data_type, buffers), - LargeList => skip_list::(field_nodes, data_type, buffers), - FixedSizeList => skip_fixed_size_list(field_nodes, data_type, buffers), - Struct => skip_struct(field_nodes, data_type, buffers), + List => skip_list::(field_nodes, data_type, buffers, variadic_buffer_counts), + LargeList => skip_list::(field_nodes, data_type, buffers, variadic_buffer_counts), + FixedSizeList => { + skip_fixed_size_list(field_nodes, data_type, buffers, variadic_buffer_counts) + }, + Struct => skip_struct(field_nodes, data_type, buffers, variadic_buffer_counts), Dictionary(_) => skip_dictionary(field_nodes, buffers), - Union => skip_union(field_nodes, data_type, buffers), - Map => skip_map(field_nodes, data_type, buffers), - BinaryView | Utf8View => todo!(), + Union => skip_union(field_nodes, data_type, buffers, variadic_buffer_counts), + Map => skip_map(field_nodes, data_type, buffers, variadic_buffer_counts), + BinaryView | Utf8View => skip_binview(field_nodes, buffers, variadic_buffer_counts), } } diff --git a/crates/polars-arrow/src/io/ipc/read/file.rs b/crates/polars-arrow/src/io/ipc/read/file.rs index 6f1f4ca8f511a..0eb60d0a566cc 100644 --- a/crates/polars-arrow/src/io/ipc/read/file.rs +++ b/crates/polars-arrow/src/io/ipc/read/file.rs @@ -1,8 +1,8 @@ -use std::convert::TryInto; use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::FooterRef; use polars_error::{polars_bail, polars_err, PolarsResult}; use polars_utils::aliases::{InitHashMaps, PlHashMap}; @@ -62,6 +62,22 @@ fn read_dictionary_message( Ok(()) } +/// Read the row count by summing the length of the of the record batches +pub fn get_row_count(reader: &mut R) -> PolarsResult { + let mut message_scratch: Vec = Default::default(); + let (_, footer_len) = read_footer_len(reader)?; + let footer = read_footer(reader, footer_len)?; + let (_, blocks) = deserialize_footer_blocks(&footer)?; + + blocks + .into_iter() + .map(|block| { + let message = get_message_from_block(reader, block, &mut message_scratch)?; + let record_batch = get_record_batch(message)?; + record_batch.length().map_err(|e| e.into()) + }) + .sum() +} pub(crate) fn get_dictionary_batch<'a>( message: &'a arrow_format::ipc::MessageRef, @@ -152,6 +168,9 @@ fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); if footer[4..] != ARROW_MAGIC_V2 { + if footer[..4] == ARROW_MAGIC_V1 { + polars_bail!(ComputeError: "feather v1 not supported"); + } return Err(polars_err!(oos = OutOfSpecKind::InvalidFooter)); } let footer_len = footer_len @@ -161,7 +180,22 @@ fn read_footer_len(reader: &mut R) -> PolarsResult<(u64, usize)> Ok((end, footer_len)) } -pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { +fn read_footer(reader: &mut R, footer_len: usize) -> PolarsResult> { + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64))?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + reader + .by_ref() + .take(footer_len as u64) + .read_to_end(&mut serialized_footer)?; + Ok(serialized_footer) +} + +fn deserialize_footer_blocks( + footer_data: &[u8], +) -> PolarsResult<(FooterRef, Vec)> { let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data) .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferFooter(err)))?; @@ -178,6 +212,11 @@ pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult< }) }) .collect::>>()?; + Ok((footer, blocks)) +} + +pub fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult { + let (footer, blocks) = deserialize_footer_blocks(footer_data)?; let ipc_schema = footer .schema() @@ -211,29 +250,9 @@ pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> PolarsResult< /// Read the Arrow IPC file's metadata pub fn read_file_metadata(reader: &mut R) -> PolarsResult { - // check if header contain the correct magic bytes - let mut magic_buffer: [u8; 6] = [0; 6]; let start = reader.stream_position()?; - reader.read_exact(&mut magic_buffer)?; - if magic_buffer != ARROW_MAGIC_V2 { - if magic_buffer[..4] == ARROW_MAGIC_V1 { - polars_bail!(ComputeError: "feather v1 not supported"); - } - polars_bail!(oos = OutOfSpecKind::InvalidHeader); - } - let (end, footer_len) = read_footer_len(reader)?; - - // read footer - reader.seek(SeekFrom::End(-10 - footer_len as i64))?; - - let mut serialized_footer = vec![]; - serialized_footer.try_reserve(footer_len)?; - reader - .by_ref() - .take(footer_len as u64) - .read_to_end(&mut serialized_footer)?; - + let serialized_footer = read_footer(reader, footer_len)?; deserialize_footer(&serialized_footer, end - start) } @@ -250,6 +269,47 @@ pub(crate) fn get_record_batch( } } +fn get_message_from_block_offset<'a, R: Read + Seek>( + reader: &mut R, + offset: u64, + message_scratch: &'a mut Vec, +) -> PolarsResult> { + // read length + reader.seek(SeekFrom::Start(offset))?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf)?; + if meta_buf == CONTINUATION_MARKER { + // continuation marker encountered, read message next + reader.read_exact(&mut meta_buf)?; + } + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; + + message_scratch.clear(); + message_scratch.try_reserve(meta_len)?; + reader + .by_ref() + .take(meta_len as u64) + .read_to_end(message_scratch)?; + + arrow_format::ipc::MessageRef::read_as_root(message_scratch) + .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err))) +} + +fn get_message_from_block<'a, R: Read + Seek>( + reader: &mut R, + block: arrow_format::ipc::Block, + message_scratch: &'a mut Vec, +) -> PolarsResult> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + get_message_from_block_offset(reader, offset, message_scratch) +} + /// Reads the record batch at position `index` from the reader. /// /// This function is useful for random access to the file. For example, if @@ -280,28 +340,7 @@ pub fn read_batch( .try_into() .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; - // read length - reader.seek(SeekFrom::Start(offset))?; - let mut meta_buf = [0; 4]; - reader.read_exact(&mut meta_buf)?; - if meta_buf == CONTINUATION_MARKER { - // continuation marker encountered, read message next - reader.read_exact(&mut meta_buf)?; - } - let meta_len = i32::from_le_bytes(meta_buf) - .try_into() - .map_err(|_| polars_err!(oos = OutOfSpecKind::UnexpectedNegativeInteger))?; - - message_scratch.clear(); - message_scratch.try_reserve(meta_len)?; - reader - .by_ref() - .take(meta_len as u64) - .read_to_end(message_scratch)?; - - let message = arrow_format::ipc::MessageRef::read_as_root(message_scratch.as_ref()) - .map_err(|err| polars_err!(oos = OutOfSpecKind::InvalidFlatbufferMessage(err)))?; - + let message = get_message_from_block_offset(reader, offset, message_scratch)?; let batch = get_record_batch(message)?; read_record_batch( diff --git a/crates/polars-arrow/src/io/ipc/read/mod.rs b/crates/polars-arrow/src/io/ipc/read/mod.rs index 3688816273e52..74d9a93a93095 100644 --- a/crates/polars-arrow/src/io/ipc/read/mod.rs +++ b/crates/polars-arrow/src/io/ipc/read/mod.rs @@ -17,6 +17,7 @@ mod schema; mod stream; pub use error::OutOfSpecKind; +pub use file::get_row_count; #[cfg(feature = "io_ipc_read_async")] #[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] @@ -29,7 +30,9 @@ pub mod file_async; pub(crate) use common::first_dict_field; #[cfg(feature = "io_flight")] pub(crate) use common::{read_dictionary, read_record_batch}; -pub use file::{read_batch, read_file_dictionaries, read_file_metadata, FileMetadata}; +pub use file::{ + deserialize_footer, read_batch, read_file_dictionaries, read_file_metadata, FileMetadata, +}; use polars_utils::aliases::PlHashMap; pub use reader::FileReader; pub use schema::deserialize_schema; diff --git a/crates/polars-arrow/src/io/ipc/read/read_basic.rs b/crates/polars-arrow/src/io/ipc/read/read_basic.rs index 3864b24bf26c8..09005ea4222e6 100644 --- a/crates/polars-arrow/src/io/ipc/read/read_basic.rs +++ b/crates/polars-arrow/src/io/ipc/read/read_basic.rs @@ -1,5 +1,4 @@ use std::collections::VecDeque; -use std::convert::TryInto; use std::io::{Read, Seek, SeekFrom}; use polars_error::{polars_bail, polars_err, PolarsResult}; @@ -97,12 +96,12 @@ fn read_uncompressed_buffer( fn read_compressed_buffer( reader: &mut R, buffer_length: usize, - length: usize, + output_length: Option, is_little_endian: bool, compression: Compression, scratch: &mut Vec, ) -> PolarsResult> { - if length == 0 { + if output_length == Some(0) { return Ok(vec![]); } @@ -112,10 +111,6 @@ fn read_compressed_buffer( ) } - // It is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read - // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 - let mut buffer = vec![T::default(); length]; - // decompress first scratch.clear(); scratch.try_reserve(buffer_length)?; @@ -124,6 +119,13 @@ fn read_compressed_buffer( .take(buffer_length as u64) .read_to_end(scratch)?; + let length = output_length + .unwrap_or_else(|| i64::from_le_bytes(scratch[..8].try_into().unwrap()) as usize); + + // It is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + let out_slice = bytemuck::cast_slice_mut(&mut buffer); let compression = compression @@ -151,7 +153,7 @@ fn read_compressed_bytes( read_compressed_buffer::( reader, buffer_length, - buffer_length, + None, is_little_endian, compression, scratch, @@ -225,7 +227,7 @@ pub fn read_buffer( Ok(read_compressed_buffer( reader, buffer_length, - length, + Some(length), is_little_endian, compression, scratch, diff --git a/crates/polars-arrow/src/io/ipc/read/stream.rs b/crates/polars-arrow/src/io/ipc/read/stream.rs index 5fab1d8262119..72b0763642946 100644 --- a/crates/polars-arrow/src/io/ipc/read/stream.rs +++ b/crates/polars-arrow/src/io/ipc/read/stream.rs @@ -1,7 +1,6 @@ use std::io::Read; use ahash::AHashMap; -use arrow_format; use arrow_format::ipc::planus::ReadAsRoot; use polars_error::{polars_bail, polars_err, PolarsError, PolarsResult}; diff --git a/crates/polars-arrow/src/io/ipc/write/schema.rs b/crates/polars-arrow/src/io/ipc/write/schema.rs index 41e88b29f7ea3..5aefef3e6684a 100644 --- a/crates/polars-arrow/src/io/ipc/write/schema.rs +++ b/crates/polars-arrow/src/io/ipc/write/schema.rs @@ -38,18 +38,13 @@ pub fn serialize_schema( .map(|(field, ipc_field)| serialize_field(field, ipc_field)) .collect::>(); - let mut custom_metadata = vec![]; - for (key, value) in &schema.metadata { - custom_metadata.push(arrow_format::ipc::KeyValue { - key: Some(key.clone()), - value: Some(value.clone()), - }); - } - let custom_metadata = if custom_metadata.is_empty() { - None - } else { - Some(custom_metadata) - }; + let custom_metadata = schema + .metadata + .iter() + .map(|(k, v)| key_value(k, v)) + .collect::>(); + + let custom_metadata = (!custom_metadata.is_empty()).then_some(custom_metadata); arrow_format::ipc::Schema { endianness, @@ -59,14 +54,17 @@ pub fn serialize_schema( } } +fn key_value(key: impl Into, val: impl Into) -> arrow_format::ipc::KeyValue { + arrow_format::ipc::KeyValue { + key: Some(key.into()), + value: Some(val.into()), + } +} + fn write_metadata(metadata: &Metadata, kv_vec: &mut Vec) { for (k, v) in metadata { if k != "ARROW:extension:name" && k != "ARROW:extension:metadata" { - let entry = arrow_format::ipc::KeyValue { - key: Some(k.clone()), - value: Some(v.clone()), - }; - kv_vec.push(entry); + kv_vec.push(key_value(k, v)); } } } @@ -76,21 +74,11 @@ fn write_extension( metadata: &Option, kv_vec: &mut Vec, ) { - // metadata if let Some(metadata) = metadata { - let entry = arrow_format::ipc::KeyValue { - key: Some("ARROW:extension:metadata".to_string()), - value: Some(metadata.clone()), - }; - kv_vec.push(entry); + kv_vec.push(key_value("ARROW:extension:metadata", metadata)); } - // name - let entry = arrow_format::ipc::KeyValue { - key: Some("ARROW:extension:name".to_string()), - value: Some(name.to_string()), - }; - kv_vec.push(entry); + kv_vec.push(key_value("ARROW:extension:name", name)); } /// Create an IPC Field from an Arrow Field diff --git a/crates/polars-arrow/src/legacy/array/slice.rs b/crates/polars-arrow/src/legacy/array/slice.rs index 63997e78b88c7..720723c901a83 100644 --- a/crates/polars-arrow/src/legacy/array/slice.rs +++ b/crates/polars-arrow/src/legacy/array/slice.rs @@ -15,6 +15,7 @@ pub trait SlicedArray { /// Slices the [`Array`]. /// # Implementation /// This operation is `O(1)`. + /// /// # Safety /// The caller must ensure that `offset + length <= self.len()` unsafe fn slice_typed_unchecked(&self, offset: usize, length: usize) -> Self diff --git a/crates/polars-arrow/src/legacy/index.rs b/crates/polars-arrow/src/legacy/index.rs index f0c874ad197dc..0c6ce660eff52 100644 --- a/crates/polars-arrow/src/legacy/index.rs +++ b/crates/polars-arrow/src/legacy/index.rs @@ -1,9 +1,7 @@ use num_traits::{NumCast, Signed, Zero}; +use polars_utils::IdxSize; -#[cfg(not(feature = "bigidx"))] -use crate::array::UInt32Array; -#[cfg(feature = "bigidx")] -use crate::array::UInt64Array; +use crate::array::PrimitiveArray; pub trait IndexToUsize { /// Translate the negative index to an offset. @@ -33,17 +31,8 @@ where } } -/// The type used by polars to index data. -#[cfg(not(feature = "bigidx"))] -pub type IdxSize = u32; -#[cfg(feature = "bigidx")] -pub type IdxSize = u64; - -#[cfg(not(feature = "bigidx"))] -pub type IdxArr = UInt32Array; -#[cfg(feature = "bigidx")] -pub type IdxArr = UInt64Array; - pub fn indexes_to_usizes(idx: &[IdxSize]) -> impl Iterator + '_ { idx.iter().map(|idx| *idx as usize) } + +pub type IdxArr = PrimitiveArray; diff --git a/crates/polars-arrow/src/legacy/is_valid.rs b/crates/polars-arrow/src/legacy/is_valid.rs index bff6095a20b2f..692937e5c60e6 100644 --- a/crates/polars-arrow/src/legacy/is_valid.rs +++ b/crates/polars-arrow/src/legacy/is_valid.rs @@ -19,11 +19,8 @@ impl ArrowArray for ListArray {} impl ArrowArray for FixedSizeListArray {} impl IsValid for A { + #[inline] unsafe fn is_valid_unchecked(&self, i: usize) -> bool { - if let Some(b) = self.validity() { - b.get_bit_unchecked(i) - } else { - true - } + !self.is_null_unchecked(i) } } diff --git a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs index 7f36e92cc14ef..3b1495b20570a 100644 --- a/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs +++ b/crates/polars-arrow/src/legacy/kernels/fixed_size_list.rs @@ -1,3 +1,5 @@ +use polars_utils::IdxSize; + use crate::array::{ArrayRef, FixedSizeListArray, PrimitiveArray}; use crate::compute::take::take_unchecked; use crate::legacy::prelude::*; diff --git a/crates/polars-arrow/src/legacy/kernels/list.rs b/crates/polars-arrow/src/legacy/kernels/list.rs index 63fad4bf60c84..e67d1638e99d4 100644 --- a/crates/polars-arrow/src/legacy/kernels/list.rs +++ b/crates/polars-arrow/src/legacy/kernels/list.rs @@ -1,3 +1,5 @@ +use polars_utils::IdxSize; + use crate::array::{ArrayRef, ListArray}; use crate::compute::take::take_unchecked; use crate::legacy::prelude::*; diff --git a/crates/polars-arrow/src/legacy/kernels/mod.rs b/crates/polars-arrow/src/legacy/kernels/mod.rs index 2c93ea0eca9d8..59f346f639edc 100644 --- a/crates/polars-arrow/src/legacy/kernels/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/mod.rs @@ -7,8 +7,10 @@ pub mod agg_mean; pub mod atan2; pub mod concatenate; pub mod ewm; +#[cfg(feature = "compute_take")] pub mod fixed_size_list; pub mod float; +#[cfg(feature = "compute_take")] pub mod list; pub mod list_bytes_iter; pub mod pow; diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs index 775c8c1f60d59..f74f88248b2f1 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mean.rs @@ -1,7 +1,5 @@ -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; use polars_error::polars_ensure; -use super::sum::SumWindow; use super::*; pub struct MeanWindow<'a, T> { @@ -19,9 +17,9 @@ impl< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { - let sum = self.sum.update(start, end); - sum / NumCast::from(end - start).unwrap() + unsafe fn update(&mut self, start: usize, end: usize) -> Option { + let sum = self.sum.update(start, end).unwrap_unchecked(); + Some(sum / NumCast::from(end - start).unwrap()) } } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs index efeeb9e183a2d..54fe9a927dde4 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/min_max.rs @@ -1,6 +1,3 @@ -use no_nulls; -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; - use super::*; #[inline] @@ -151,7 +148,7 @@ macro_rules! minmax_window { } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { //For details see: https://github.com/pola-rs/polars/pull/9277#issuecomment-1581401692 self.last_start = start; // Don't care where the last one started let old_last_end = self.last_end; // But we need this @@ -171,10 +168,10 @@ macro_rules! minmax_window { if entering.map(|em| $new_is_m(&self.m, em.1) || empty_overlap) == Some(true) { // The entering extremum "beats" the previous extremum so we can ignore the overlap self.update_m_and_m_idx(entering.unwrap()); - return self.m; + return Some(self.m); } else if self.m_idx >= start || empty_overlap { // The previous extremum didn't drop off. Keep it - return self.m; + return Some(self.m); } // Otherwise get the min of the overlapping window and the entering min match ( @@ -194,7 +191,7 @@ macro_rules! minmax_window { (None, None) => unreachable!(), } - self.m + Some(self.m) } } }; @@ -244,7 +241,7 @@ macro_rules! rolling_minmax_func { _params: DynArgs, ) -> PolarsResult where - T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul, + T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul + Num, { let offset_fn = match center { true => det_offsets_center, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs index d680061f60a6f..ffe04bcfd5984 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs @@ -3,12 +3,11 @@ mod min_max; mod quantile; mod sum; mod variance; - use std::fmt::Debug; pub use mean::*; pub use min_max::*; -use num_traits::{Float, NumCast}; +use num_traits::{Float, Num, NumCast}; pub use quantile::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -18,17 +17,17 @@ pub use variance::*; use super::*; use crate::array::PrimitiveArray; use crate::datatypes::ArrowDataType; -use crate::legacy::error::{polars_bail, PolarsResult}; -use crate::legacy::utils::CustomIterTools; +use crate::legacy::error::PolarsResult; use crate::types::NativeType; pub trait RollingAggWindowNoNulls<'a, T: NativeType> { fn new(slice: &'a [T], start: usize, end: usize, params: DynArgs) -> Self; /// Update and recompute the window + /// /// # Safety /// `start` and `end` must be within the windows bounds - unsafe fn update(&mut self, start: usize, end: usize) -> T; + unsafe fn update(&mut self, start: usize, end: usize) -> Option; } // Use an aggregation window that maintains the state @@ -42,27 +41,34 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>( where Fo: Fn(Idx, WindowSize, Len) -> (Start, End), Agg: RollingAggWindowNoNulls<'a, T>, - T: Debug + NativeType, + T: Debug + NativeType + Num, { let len = values.len(); let (start, end) = det_offsets_fn(0, window_size, len); let mut agg_window = Agg::new(values, start, end, params); + if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) { + if validity.iter().all(|x| !x) { + return Ok(Box::new(PrimitiveArray::::new_null( + T::PRIMITIVE.into(), + len, + ))); + } + } let out = (0..len) .map(|idx| { let (start, end) = det_offsets_fn(idx, window_size, len); - // SAFETY: - // we are in bounds - unsafe { agg_window.update(start, end) } + if end - start < min_periods { + None + } else { + // SAFETY: + // we are in bounds + unsafe { agg_window.update(start, end) } + } }) .collect_trusted::>(); - - let validity = create_validity(min_periods, len, window_size, det_offsets_fn); - Ok(Box::new(PrimitiveArray::new( - T::PRIMITIVE.into(), - out.into(), - validity.map(|b| b.into()), - ))) + let arr = PrimitiveArray::from(out); + Ok(Box::new(arr)) } #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index 92b78eb12b608..50b7702bdbddc 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -1,5 +1,3 @@ -use std::fmt::Debug; - use num_traits::ToPrimitive; use polars_error::polars_ensure; use polars_utils::slice::GetSaferUnchecked; @@ -37,7 +35,7 @@ impl< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { let vals = self.sorted.update(start, end); let length = vals.len(); @@ -50,13 +48,13 @@ impl< let float_idx_top = (length_f - 1.0) * self.prob; let top_idx = float_idx_top.ceil() as usize; return if idx == top_idx { - unsafe { *vals.get_unchecked_release(idx) } + Some(unsafe { *vals.get_unchecked_release(idx) }) } else { let proportion = T::from(float_idx_top - idx as f64).unwrap(); let vi = unsafe { *vals.get_unchecked_release(idx) }; let vj = unsafe { *vals.get_unchecked_release(top_idx) }; - proportion * (vj - vi) + vi + Some(proportion * (vj - vi) + vi) }; }, Midpoint => { @@ -68,7 +66,7 @@ impl< return if top_idx == idx { // SAFETY: // we are in bounds - unsafe { *vals.get_unchecked_release(idx) } + Some(unsafe { *vals.get_unchecked_release(idx) }) } else { // SAFETY: // we are in bounds @@ -79,7 +77,7 @@ impl< ) }; - (mid + mid_plus_1) / (T::one() + T::one()) + Some((mid + mid_plus_1) / (T::one() + T::one())) }; }, Nearest => { @@ -95,7 +93,7 @@ impl< // SAFETY: // we are in bounds - unsafe { *vals.get_unchecked_release(idx) } + Some(unsafe { *vals.get_unchecked_release(idx) }) } } @@ -261,7 +259,6 @@ where #[cfg(test)] mod test { use super::*; - use crate::legacy::kernels::rolling::no_nulls::{rolling_max, rolling_min}; #[test] fn test_rolling_median() { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs index ef5efa52d1efe..b66a3a4fc5ed1 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs @@ -1,6 +1,3 @@ -use no_nulls; -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; - use super::*; pub struct SumWindow<'a, T> { @@ -23,7 +20,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { // if we exceed the end, we have a completely new window // so we recompute let recompute_sum = if start >= self.last_end { @@ -63,7 +60,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> } } self.last_end = end; - self.sum + Some(self.sum) } } @@ -76,7 +73,14 @@ pub fn rolling_sum( _params: DynArgs, ) -> PolarsResult where - T: NativeType + std::iter::Sum + NumCast + Mul + AddAssign + SubAssign + IsFloat, + T: NativeType + + std::iter::Sum + + NumCast + + Mul + + AddAssign + + SubAssign + + IsFloat + + Num, { match (center, weights) { (true, None) => rolling_apply_agg_window::, _, _>( diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs index 64273e91b467f..4e3de45cfeff3 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/variance.rs @@ -1,7 +1,5 @@ -use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls}; use polars_error::polars_ensure; -use super::mean::MeanWindow; use super::*; pub(super) struct SumSquaredWindow<'a, T> { @@ -28,7 +26,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { // if we exceed the end, we have a completely new window // so we recompute let recompute_sum = if start >= self.last_end || self.last_recompute > 128 { @@ -70,7 +68,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul< } } self.last_end = end; - self.sum_of_squares + Some(self.sum_of_squares) } } @@ -110,25 +108,24 @@ impl< } } - unsafe fn update(&mut self, start: usize, end: usize) -> T { + unsafe fn update(&mut self, start: usize, end: usize) -> Option { let count: T = NumCast::from(end - start).unwrap(); - let sum_of_squares = self.sum_of_squares.update(start, end); - let mean = self.mean.update(start, end); + let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked(); + let mean = self.mean.update(start, end).unwrap_unchecked(); let denom = count - NumCast::from(self.ddof).unwrap(); - if end - start == 1 { - T::zero() - } else if denom <= T::zero() { - //ddof would be greater than # of observations - T::infinity() + if denom <= T::zero() { + None + } else if end - start == 1 { + Some(T::zero()) } else { let out = (sum_of_squares - count * mean * mean) / denom; // variance cannot be negative. // if it is negative it is due to numeric instability if out < T::zero() { - T::zero() + Some(T::zero()) } else { - out + Some(out) } } } @@ -210,14 +207,11 @@ mod test { let out = rolling_var(values, 2, 1, false, None, None).unwrap(); let out = out.as_any().downcast_ref::>().unwrap(); - let out = out - .into_iter() - .map(|v| v.copied().unwrap()) - .collect::>(); + let out = out.into_iter().map(|v| v.copied()).collect::>(); // we cannot compare nans, so we compare the string values assert_eq!( format!("{:?}", out.as_slice()), - format!("{:?}", &[0.0, 8.0, 2.0, 0.5]) + format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)]) ); // test nan handling. let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0]; diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs index 4e2d915b09f1b..cbd3c7e981d44 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mean.rs @@ -1,5 +1,4 @@ -use super::sum::SumWindow; -use super::{rolling_apply_agg_window, RollingAggWindowNulls, *}; +use super::*; pub struct MeanWindow<'a, T> { sum: SumWindow<'a, T>, diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs index 334c187e84140..55ea003f5c5a8 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/min_max.rs @@ -1,6 +1,3 @@ -use nulls; -use nulls::{rolling_apply_agg_window, RollingAggWindowNulls}; - use super::*; use crate::array::iterator::NonNullValuesIter; use crate::bitmap::utils::count_zeros; diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs index 26037f10794a9..4174c5b498261 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs @@ -94,7 +94,6 @@ mod test { use crate::array::{Array, Int32Array}; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; - use crate::legacy::kernels::rolling::nulls::mean::rolling_mean; fn get_null_arr() -> PrimitiveArray { // 1, None, -1, 4 diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 5db89c320fb57..f10616547b251 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -155,7 +155,6 @@ mod test { use super::*; use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; - use crate::legacy::kernels::rolling::nulls::{rolling_max, rolling_min}; #[test] fn test_rolling_median_nulls() { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs index 9c5d2125edaa9..876f60187a791 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/sum.rs @@ -1,6 +1,3 @@ -use nulls; -use nulls::{rolling_apply_agg_window, RollingAggWindowNulls}; - use super::*; pub struct SumWindow<'a, T> { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs index bc4b3da5b0cd5..1793fc32a4b75 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/variance.rs @@ -1,7 +1,3 @@ -use mean::MeanWindow; -use nulls; -use nulls::{rolling_apply_agg_window, RollingAggWindowNulls}; - use super::*; pub(super) struct SumSquaredWindow<'a, T> { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/window.rs b/crates/polars-arrow/src/legacy/kernels/rolling/window.rs index 3a874762b5534..9ce4b586a8365 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/window.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/window.rs @@ -24,6 +24,7 @@ impl<'a, T: NativeType> SortedBuf<'a, T> { } /// Update the window position by setting the `start` index and the `end` index. + /// /// # Safety /// The caller must ensure that `start` and `end` are within bounds of `self.slice` /// @@ -120,6 +121,7 @@ impl<'a, T: NativeType> SortedBufNulls<'a, T> { } /// Update the window position by setting the `start` index and the `end` index. + /// /// # Safety /// The caller must ensure that `start` and `end` are within bounds of `self.slice` /// diff --git a/crates/polars-arrow/src/legacy/kernels/set.rs b/crates/polars-arrow/src/legacy/kernels/set.rs index 4fd87905bb742..41f3dbcf5c3d4 100644 --- a/crates/polars-arrow/src/legacy/kernels/set.rs +++ b/crates/polars-arrow/src/legacy/kernels/set.rs @@ -1,12 +1,12 @@ use std::ops::BitOr; use polars_error::polars_err; +use polars_utils::IdxSize; use crate::array::*; use crate::datatypes::ArrowDataType; use crate::legacy::array::default_arrays::FromData; use crate::legacy::error::PolarsResult; -use crate::legacy::index::IdxSize; use crate::legacy::kernels::BinaryMaskedSliceIterator; use crate::legacy::trusted_len::TrustedLenPush; use crate::types::NativeType; @@ -97,10 +97,7 @@ where #[cfg(test)] mod test { - use std::iter::FromIterator; - use super::*; - use crate::array::UInt32Array; #[test] fn test_set_mask() { diff --git a/crates/polars-arrow/src/legacy/kernels/sort_partition.rs b/crates/polars-arrow/src/legacy/kernels/sort_partition.rs index 3021e9a330c0d..cf1c8c56402a3 100644 --- a/crates/polars-arrow/src/legacy/kernels/sort_partition.rs +++ b/crates/polars-arrow/src/legacy/kernels/sort_partition.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; -use crate::legacy::index::IdxSize; +use polars_utils::IdxSize; + use crate::types::NativeType; /// Find partition indexes such that every partition contains unique groups. diff --git a/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs b/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs index 5f87df216a695..3d41574bcbf82 100644 --- a/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs +++ b/crates/polars-arrow/src/legacy/kernels/sorted_join/left.rs @@ -11,7 +11,7 @@ pub fn join( if right.is_empty() { return ( (left_offset..left.len() as IdxSize + left_offset).collect(), - vec![None; left.len()], + vec![NullableIdxSize::null(); left.len()], ); } // * 1.5 because there can be duplicates @@ -27,7 +27,7 @@ pub fn join( let first_right = right[right_idx as usize]; let mut left_idx = left.partition_point(|v| v < &first_right) as IdxSize; - out_rhs.extend(std::iter::repeat(None).take(left_idx as usize)); + out_rhs.extend(std::iter::repeat(NullableIdxSize::null()).take(left_idx as usize)); out_lhs.extend(left_offset..(left_idx + left_offset)); for &val_l in &left[left_idx as usize..] { @@ -37,7 +37,7 @@ pub fn join( // matching join key if val_l == val_r { out_lhs.push(left_idx + left_offset); - out_rhs.push(Some(right_idx)); + out_rhs.push(right_idx.into()); let current_idx = right_idx; loop { @@ -52,7 +52,7 @@ pub fn join( Some(&val_r) => { if val_l == val_r { out_lhs.push(left_idx + left_offset); - out_rhs.push(Some(right_idx)); + out_rhs.push(right_idx.into()); } else { // reset right index because the next lhs value can be the same right_idx = current_idx; @@ -67,7 +67,7 @@ pub fn join( // right is larger than left. if val_r > val_l { out_lhs.push(left_idx + left_offset); - out_rhs.push(None); + out_rhs.push(NullableIdxSize::null()); break; } // continue looping the right side @@ -76,7 +76,7 @@ pub fn join( // we depleted the right array None => { out_lhs.push(left_idx + left_offset); - out_rhs.push(None); + out_rhs.push(NullableIdxSize::null()); break; }, } @@ -98,14 +98,14 @@ mod test { let (l_idx, r_idx) = join(lhs, rhs, 0); let out_left = &[0, 1, 1, 2, 2, 3, 4, 5]; let out_right = &[ - Some(0), - Some(1), - Some(2), - Some(1), - Some(2), - None, - Some(3), - None, + 0.into(), + 1.into(), + 2.into(), + 1.into(), + 2.into(), + NullableIdxSize::null(), + 3.into(), + NullableIdxSize::null(), ]; assert_eq!(&l_idx, out_left); assert_eq!(&r_idx, out_right); @@ -128,21 +128,21 @@ mod test { assert_eq!( &r_idx, &[ - Some(0), - Some(1), - Some(0), - Some(1), - Some(2), - Some(3), - Some(4), - None, - Some(5), - Some(6), - Some(5), - Some(6), - Some(5), - Some(6), - None + 0.into(), + 1.into(), + 0.into(), + 1.into(), + 2.into(), + 3.into(), + 4.into(), + NullableIdxSize::null(), + 5.into(), + 6.into(), + 5.into(), + 6.into(), + 5.into(), + 6.into(), + NullableIdxSize::null(), ] ); @@ -153,16 +153,16 @@ mod test { assert_eq!( &r_idx, &[ - None, - None, - Some(1), - Some(2), - Some(2), - Some(2), - Some(2), - Some(3), - Some(4), - Some(4) + NullableIdxSize::null(), + NullableIdxSize::null(), + 1.into(), + 2.into(), + 2.into(), + 2.into(), + 2.into(), + 3.into(), + 4.into(), + 4.into() ] ); let lhs = &[0, 1, 2, 2, 3, 4, 4, 6, 6, 7]; @@ -172,20 +172,20 @@ mod test { assert_eq!( &r_idx, &[ - None, - None, - None, - None, - None, - Some(0), - Some(1), - Some(2), - Some(0), - Some(1), - Some(2), - None, - None, - None + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), + 0.into(), + 1.into(), + 2.into(), + 0.into(), + 1.into(), + 2.into(), + NullableIdxSize::null(), + NullableIdxSize::null(), + NullableIdxSize::null(), ] ) } diff --git a/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs b/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs index 0ba3c648ff45b..5aea170f30d59 100644 --- a/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/sorted_join/mod.rs @@ -3,9 +3,9 @@ pub mod left; use std::fmt::Debug; -use crate::legacy::index::IdxSize; +use polars_utils::{IdxSize, NullableIdxSize}; -type JoinOptIds = Vec>; +type JoinOptIds = Vec; type JoinIds = Vec; type LeftJoinIds = (JoinIds, JoinOptIds); type InnerJoinIds = (JoinIds, JoinIds); diff --git a/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs b/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs index 77213de4d8bf5..9daf37389837c 100644 --- a/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/take_agg/mod.rs @@ -4,10 +4,10 @@ mod var; pub use boolean::*; use num_traits::{NumCast, ToPrimitive}; +use polars_utils::IdxSize; pub use var::*; use crate::array::{Array, BinaryViewArray, BooleanArray, PrimitiveArray}; -use crate::legacy::index::IdxSize; use crate::types::NativeType; /// Take kernel for single chunk without nulls and an iterator as index. diff --git a/crates/polars-arrow/src/legacy/kernels/time.rs b/crates/polars-arrow/src/legacy/kernels/time.rs index ab20c57b80e20..365f4fd5d3168 100644 --- a/crates/polars-arrow/src/legacy/kernels/time.rs +++ b/crates/polars-arrow/src/legacy/kernels/time.rs @@ -11,6 +11,7 @@ use polars_error::{polars_bail, PolarsError}; pub enum Ambiguous { Earliest, Latest, + Null, Raise, } impl FromStr for Ambiguous { @@ -21,8 +22,9 @@ impl FromStr for Ambiguous { "earliest" => Ok(Ambiguous::Earliest), "latest" => Ok(Ambiguous::Latest), "raise" => Ok(Ambiguous::Raise), + "null" => Ok(Ambiguous::Null), s => polars_bail!(InvalidOperation: - "Invalid argument {}, expected one of: \"earliest\", \"latest\", \"raise\"", s + "Invalid argument {}, expected one of: \"earliest\", \"latest\", \"null\", \"raise\"", s ), } } @@ -34,13 +36,14 @@ pub fn convert_to_naive_local( to_tz: &Tz, ndt: NaiveDateTime, ambiguous: Ambiguous, -) -> PolarsResult { +) -> PolarsResult> { let ndt = from_tz.from_utc_datetime(&ndt).naive_local(); match to_tz.from_local_datetime(&ndt) { - LocalResult::Single(dt) => Ok(dt.naive_utc()), + LocalResult::Single(dt) => Ok(Some(dt.naive_utc())), LocalResult::Ambiguous(dt_earliest, dt_latest) => match ambiguous { - Ambiguous::Earliest => Ok(dt_earliest.naive_utc()), - Ambiguous::Latest => Ok(dt_latest.naive_utc()), + Ambiguous::Earliest => Ok(Some(dt_earliest.naive_utc())), + Ambiguous::Latest => Ok(Some(dt_latest.naive_utc())), + Ambiguous::Null => Ok(None), Ambiguous::Raise => { polars_bail!(ComputeError: "datetime '{}' is ambiguous in time zone '{}'. Please use `ambiguous` to tell how it should be localized.", ndt, to_tz) }, @@ -52,19 +55,22 @@ pub fn convert_to_naive_local( } } +/// Same as convert_to_naive_local, but return `None` instead +/// raising - in some cases this can be used to save a string allocation. #[cfg(feature = "timezones")] pub fn convert_to_naive_local_opt( from_tz: &Tz, to_tz: &Tz, ndt: NaiveDateTime, ambiguous: Ambiguous, -) -> Option { +) -> Option> { let ndt = from_tz.from_utc_datetime(&ndt).naive_local(); match to_tz.from_local_datetime(&ndt) { - LocalResult::Single(dt) => Some(dt.naive_utc()), + LocalResult::Single(dt) => Some(Some(dt.naive_utc())), LocalResult::Ambiguous(dt_earliest, dt_latest) => match ambiguous { - Ambiguous::Earliest => Some(dt_earliest.naive_utc()), - Ambiguous::Latest => Some(dt_latest.naive_utc()), + Ambiguous::Earliest => Some(Some(dt_earliest.naive_utc())), + Ambiguous::Latest => Some(Some(dt_latest.naive_utc())), + Ambiguous::Null => Some(None), Ambiguous::Raise => None, }, LocalResult::None => None, diff --git a/crates/polars-arrow/src/offset.rs b/crates/polars-arrow/src/offset.rs index e96cf30a025c9..ae8f568b3eb17 100644 --- a/crates/polars-arrow/src/offset.rs +++ b/crates/polars-arrow/src/offset.rs @@ -141,6 +141,7 @@ impl Offsets { } /// Returns [`Offsets`] assuming that `offsets` fulfills its invariants + /// /// # Safety /// This is safe iff the invariants of this struct are guaranteed in `offsets`. #[inline] @@ -168,6 +169,7 @@ impl Offsets { } /// Returns a range (start, end) corresponding to the position `index` + /// /// # Safety /// `index` must be `< self.len()` #[inline] @@ -441,6 +443,7 @@ impl OffsetsBuffer { } /// Returns a range (start, end) corresponding to the position `index` + /// /// # Safety /// `index` must be `< self.len()` #[inline] @@ -462,6 +465,7 @@ impl OffsetsBuffer { } /// Slices this [`OffsetsBuffer`] starting at `offset`. + /// /// # Safety /// The caller must ensure `offset + length <= self.len()` #[inline] diff --git a/crates/polars-arrow/src/pushable.rs b/crates/polars-arrow/src/pushable.rs index db71d8726a8a2..12d04ebdcf762 100644 --- a/crates/polars-arrow/src/pushable.rs +++ b/crates/polars-arrow/src/pushable.rs @@ -15,6 +15,7 @@ pub trait Pushable: Sized + Default { fn len(&self) -> usize; fn push_null(&mut self); fn extend_constant(&mut self, additional: usize, value: T); + fn extend_null_constant(&mut self, additional: usize); } impl Pushable for MutableBitmap { @@ -41,6 +42,11 @@ impl Pushable for MutableBitmap { fn extend_constant(&mut self, additional: usize, value: bool) { self.extend_constant(additional, value) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional, false) + } } impl Pushable for Vec { @@ -67,6 +73,11 @@ impl Pushable for Vec { fn extend_constant(&mut self, additional: usize, value: T) { self.resize(self.len() + additional, value); } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional, T::default()) + } } impl Pushable for Offsets { fn reserve(&mut self, additional: usize) { @@ -91,6 +102,11 @@ impl Pushable for Offsets { fn extend_constant(&mut self, additional: usize, _: usize) { self.extend_constant(additional) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } } impl Pushable> for MutablePrimitiveArray { @@ -118,6 +134,11 @@ impl Pushable> for MutablePrimitiveArray { fn extend_constant(&mut self, additional: usize, value: Option) { MutablePrimitiveArray::extend_constant(self, additional, value) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + MutablePrimitiveArray::extend_constant(self, additional, None) + } } impl Pushable<&T> for MutableBinaryViewArray { @@ -157,4 +178,9 @@ impl Pushable<&T> for MutableBinaryViewArray { bitmap.extend_constant(remaining, true) } } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_null(additional); + } } diff --git a/crates/polars-arrow/src/scalar/equal.rs b/crates/polars-arrow/src/scalar/equal.rs index 0d02ca9c9d61d..c18d634559131 100644 --- a/crates/polars-arrow/src/scalar/equal.rs +++ b/crates/polars-arrow/src/scalar/equal.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use super::*; -use crate::datatypes::PhysicalType; use crate::{match_integer_type, with_match_primitive_type}; impl PartialEq for dyn Scalar + '_ { @@ -53,6 +52,7 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), Union => dyn_eq!(UnionScalar, lhs, rhs), Map => dyn_eq!(MapScalar, lhs, rhs), + Utf8View => dyn_eq!(BinaryViewScalar, lhs, rhs), _ => unimplemented!(), } } diff --git a/crates/polars-arrow/src/temporal_conversions.rs b/crates/polars-arrow/src/temporal_conversions.rs index 8a9a792ed9929..75a77664dba79 100644 --- a/crates/polars-arrow/src/temporal_conversions.rs +++ b/crates/polars-arrow/src/temporal_conversions.rs @@ -1,7 +1,7 @@ //! Conversion methods for dates and times. use chrono::format::{parse, Parsed, StrftimeItems}; -use chrono::{Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use chrono::{DateTime, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeDelta}; use polars_error::{polars_err, PolarsResult}; use crate::array::{PrimitiveArray, Utf8ViewArray}; @@ -29,7 +29,8 @@ pub fn date32_to_datetime(v: i32) -> NaiveDateTime { /// converts a `i32` representing a `date32` to [`NaiveDateTime`] #[inline] pub fn date32_to_datetime_opt(v: i32) -> Option { - NaiveDateTime::from_timestamp_opt(v as i64 * SECONDS_IN_DAY, 0) + let delta = TimeDelta::try_days(v.into())?; + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) } /// converts a `i32` representing a `date32` to [`NaiveDate`] @@ -47,13 +48,9 @@ pub fn date32_to_date_opt(days: i32) -> Option { /// converts a `i64` representing a `date64` to [`NaiveDateTime`] #[inline] pub fn date64_to_datetime(v: i64) -> NaiveDateTime { - NaiveDateTime::from_timestamp_opt( - // extract seconds from milliseconds - v / MILLISECONDS, - // discard extracted seconds and convert milliseconds to nanoseconds - (v % MILLISECONDS * MICROSECONDS) as u32, - ) - .expect("invalid or out-of-range datetime") + TimeDelta::try_milliseconds(v) + .and_then(|delta| NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta)) + .expect("invalid or out-of-range datetime") } /// converts a `i64` representing a `date64` to [`NaiveDate`] @@ -71,13 +68,13 @@ pub fn time32s_to_time(v: i32) -> NaiveTime { /// converts a `i64` representing a `duration(s)` to [`Duration`] #[inline] pub fn duration_s_to_duration(v: i64) -> Duration { - Duration::seconds(v) + Duration::try_seconds(v).expect("out-of-range duration") } /// converts a `i64` representing a `duration(ms)` to [`Duration`] #[inline] pub fn duration_ms_to_duration(v: i64) -> Duration { - Duration::milliseconds(v) + Duration::try_milliseconds(v).expect("out-of-range in duration conversion") } /// converts a `i64` representing a `duration(us)` to [`Duration`] @@ -148,7 +145,7 @@ pub fn timestamp_s_to_datetime(seconds: i64) -> NaiveDateTime { /// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] #[inline] pub fn timestamp_s_to_datetime_opt(seconds: i64) -> Option { - NaiveDateTime::from_timestamp_opt(seconds, 0) + Some(DateTime::from_timestamp(seconds, 0)?.naive_utc()) } /// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] @@ -160,27 +157,8 @@ pub fn timestamp_ms_to_datetime(v: i64) -> NaiveDateTime { /// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] #[inline] pub fn timestamp_ms_to_datetime_opt(v: i64) -> Option { - if v >= 0 { - NaiveDateTime::from_timestamp_opt( - // extract seconds from milliseconds - v / MILLISECONDS, - // discard extracted seconds and convert milliseconds to nanoseconds - (v % MILLISECONDS * MICROSECONDS) as u32, - ) - } else { - let secs_rem = (v / MILLISECONDS, v % MILLISECONDS); - if secs_rem.1 == 0 { - // whole/integer seconds; no adjustment required - NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) - } else { - // negative values with fractional seconds require 'div_floor' rounding behaviour. - // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) - NaiveDateTime::from_timestamp_opt( - secs_rem.0 - 1, - (NANOSECONDS + (v % MILLISECONDS * MICROSECONDS)) as u32, - ) - } - } + let delta = TimeDelta::try_milliseconds(v)?; + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) } /// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] @@ -192,27 +170,8 @@ pub fn timestamp_us_to_datetime(v: i64) -> NaiveDateTime { /// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] #[inline] pub fn timestamp_us_to_datetime_opt(v: i64) -> Option { - if v >= 0 { - NaiveDateTime::from_timestamp_opt( - // extract seconds from microseconds - v / MICROSECONDS, - // discard extracted seconds and convert microseconds to nanoseconds - (v % MICROSECONDS * MILLISECONDS) as u32, - ) - } else { - let secs_rem = (v / MICROSECONDS, v % MICROSECONDS); - if secs_rem.1 == 0 { - // whole/integer seconds; no adjustment required - NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) - } else { - // negative values with fractional seconds require 'div_floor' rounding behaviour. - // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) - NaiveDateTime::from_timestamp_opt( - secs_rem.0 - 1, - (NANOSECONDS + (v % MICROSECONDS * MILLISECONDS)) as u32, - ) - } - } + let delta = TimeDelta::microseconds(v); + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) } /// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] @@ -224,27 +183,8 @@ pub fn timestamp_ns_to_datetime(v: i64) -> NaiveDateTime { /// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] #[inline] pub fn timestamp_ns_to_datetime_opt(v: i64) -> Option { - if v >= 0 { - NaiveDateTime::from_timestamp_opt( - // extract seconds from nanoseconds - v / NANOSECONDS, - // discard extracted seconds - (v % NANOSECONDS) as u32, - ) - } else { - let secs_rem = (v / NANOSECONDS, v % NANOSECONDS); - if secs_rem.1 == 0 { - // whole/integer seconds; no adjustment required - NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) - } else { - // negative values with fractional seconds require 'div_floor' rounding behaviour. - // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) - NaiveDateTime::from_timestamp_opt( - secs_rem.0 - 1, - (NANOSECONDS + (v % NANOSECONDS)) as u32, - ) - } - } + let delta = TimeDelta::nanoseconds(v); + NaiveDateTime::UNIX_EPOCH.checked_add_signed(delta) } /// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. @@ -362,10 +302,10 @@ pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> parsed .to_naive_datetime_with_offset(0) .map(|x| match tu { - TimeUnit::Second => x.timestamp(), - TimeUnit::Millisecond => x.timestamp_millis(), - TimeUnit::Microsecond => x.timestamp_micros(), - TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + TimeUnit::Second => x.and_utc().timestamp(), + TimeUnit::Millisecond => x.and_utc().timestamp_millis(), + TimeUnit::Microsecond => x.and_utc().timestamp_micros(), + TimeUnit::Nanosecond => x.and_utc().timestamp_nanos_opt().unwrap(), }) .ok() } diff --git a/crates/polars-arrow/src/types/index.rs b/crates/polars-arrow/src/types/index.rs index 0aedea008fa39..83299a76980fc 100644 --- a/crates/polars-arrow/src/types/index.rs +++ b/crates/polars-arrow/src/types/index.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use super::NativeType; use crate::trusted_len::TrustedLen; @@ -99,5 +97,7 @@ impl Iterator for IndexRange { } } -/// Safety: a range is always of known length +/// # Safety +/// +/// A range is always of known length. unsafe impl TrustedLen for IndexRange {} diff --git a/crates/polars-arrow/src/types/native.rs b/crates/polars-arrow/src/types/native.rs index 45d8d7cb665fb..95966004b8488 100644 --- a/crates/polars-arrow/src/types/native.rs +++ b/crates/polars-arrow/src/types/native.rs @@ -1,4 +1,3 @@ -use std::convert::TryFrom; use std::ops::Neg; use std::panic::RefUnwindSafe; diff --git a/crates/polars-arrow/src/types/simd/mod.rs b/crates/polars-arrow/src/types/simd/mod.rs index d906c9d25e950..2666abe2ba2c9 100644 --- a/crates/polars-arrow/src/types/simd/mod.rs +++ b/crates/polars-arrow/src/types/simd/mod.rs @@ -123,6 +123,7 @@ macro_rules! native_simd { }; } +#[cfg(not(feature = "simd"))] pub(super) use native_simd; // Types do not have specific intrinsics and thus SIMD can't be specialized. diff --git a/crates/polars-arrow/src/types/simd/native.rs b/crates/polars-arrow/src/types/simd/native.rs index af31b8b26bc0d..f0cb5436f4f30 100644 --- a/crates/polars-arrow/src/types/simd/native.rs +++ b/crates/polars-arrow/src/types/simd/native.rs @@ -1,7 +1,4 @@ -use std::convert::TryInto; - use super::*; -use crate::types::BitChunkIter; native_simd!(u8x64, u8, 64, u64); native_simd!(u16x32, u16, 32, u32); diff --git a/crates/polars-arrow/tests/it/ffi/mod.rs b/crates/polars-arrow/tests/it/ffi/mod.rs deleted file mode 100644 index 36d8589f579b4..0000000000000 --- a/crates/polars-arrow/tests/it/ffi/mod.rs +++ /dev/null @@ -1 +0,0 @@ -mod data; diff --git a/crates/polars-arrow/tests/it/main.rs b/crates/polars-arrow/tests/it/main.rs deleted file mode 100644 index a21dad004e519..0000000000000 --- a/crates/polars-arrow/tests/it/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod ffi; -#[cfg(feature = "io_ipc_compression")] -mod io; diff --git a/crates/polars-compute/Cargo.toml b/crates/polars-compute/Cargo.toml index 14be8be65f804..4ade7134ec5e3 100644 --- a/crates/polars-compute/Cargo.toml +++ b/crates/polars-compute/Cargo.toml @@ -17,10 +17,13 @@ polars-error = { workspace = true } polars-utils = { workspace = true } strength_reduce = { workspace = true } +[dev-dependencies] +rand = { workspace = true } + [build-dependencies] version_check = { workspace = true } [features] nightly = [] -simd = [] +simd = ["arrow/simd"] dtype-array = [] diff --git a/crates/polars-compute/src/arithmetic/signed.rs b/crates/polars-compute/src/arithmetic/signed.rs index 6e500ecdb5c94..94a13ba10394e 100644 --- a/crates/polars-compute/src/arithmetic/signed.rs +++ b/crates/polars-compute/src/arithmetic/signed.rs @@ -1,6 +1,6 @@ use arrow::array::{PrimitiveArray as PArr, StaticArray}; use arrow::compute::utils::{combine_validities_and, combine_validities_and3}; -use polars_utils::signed_divmod::SignedDivMod; +use polars_utils::floor_divmod::FloorDivMod; use strength_reduce::*; use super::PrimitiveArithmeticKernelImpl; @@ -35,7 +35,8 @@ macro_rules! impl_signed_arith_kernel { other.take_validity().as_ref(), // compute combination twice. Some(&mask), ); - let ret = prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_div_mod(rhs).0); + let ret = + prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).0); ret.with_validity(valid) } @@ -63,7 +64,8 @@ macro_rules! impl_signed_arith_kernel { other.take_validity().as_ref(), // compute combination twice. Some(&mask), ); - let ret = prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_div_mod(rhs).1); + let ret = + prim_binary_values(lhs, other, |lhs, rhs| lhs.wrapping_floor_div_mod(rhs).1); ret.with_validity(valid) } @@ -133,7 +135,7 @@ macro_rules! impl_signed_arith_kernel { let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| lhs.wrapping_div_mod(x).0); + let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).0); ret.with_validity(valid) } @@ -205,7 +207,7 @@ macro_rules! impl_signed_arith_kernel { let mask = rhs.tot_ne_kernel_broadcast(&0); let valid = combine_validities_and(rhs.validity(), Some(&mask)); - let ret = prim_unary_values(rhs, |x| lhs.wrapping_div_mod(x).1); + let ret = prim_unary_values(rhs, |x| lhs.wrapping_floor_div_mod(x).1); ret.with_validity(valid) } diff --git a/crates/polars-compute/src/arity.rs b/crates/polars-compute/src/arity.rs index 8fec0d3a513cf..33c8b0eb05843 100644 --- a/crates/polars-compute/src/arity.rs +++ b/crates/polars-compute/src/arity.rs @@ -8,7 +8,7 @@ use arrow::types::NativeType; /// /// # Safety /// - arr must point to a readable slice of length len. -/// - out must point to a writeable slice of length len. +/// - out must point to a writable slice of length len. #[inline(never)] unsafe fn ptr_apply_unary_kernel O>( arr: *const I, @@ -25,7 +25,7 @@ unsafe fn ptr_apply_unary_kernel O>( /// # Safety /// - left must point to a readable slice of length len. /// - right must point to a readable slice of length len. -/// - out must point to a writeable slice of length len. +/// - out must point to a writable slice of length len. #[inline(never)] unsafe fn ptr_apply_binary_kernel O>( left: *const L, diff --git a/crates/polars-compute/src/filter/avx512.rs b/crates/polars-compute/src/filter/avx512.rs new file mode 100644 index 0000000000000..11466b137be8d --- /dev/null +++ b/crates/polars-compute/src/filter/avx512.rs @@ -0,0 +1,114 @@ +use core::arch::x86_64::*; + +// It's not possible to inline target_feature(enable = ...) functions into other +// functions without that enabled, so we use a macro for these very-similarly +// structured functions. +macro_rules! simd_filter { + ($values: ident, $mask_bytes: ident, $out: ident, |$subchunk: ident, $m: ident: $MaskT: ty| $body:block) => {{ + const MASK_BITS: usize = std::mem::size_of::<$MaskT>() * 8; + + // Do a 64-element loop for sparse fast path. + let chunks = $values.chunks_exact(64); + $values = chunks.remainder(); + for chunk in chunks { + let mask_chunk = $mask_bytes.get_unchecked(..8); + $mask_bytes = $mask_bytes.get_unchecked(8..); + let mut m64 = u64::from_le_bytes(mask_chunk.try_into().unwrap()); + + // Fast-path: skip entire 64-element chunk. + if m64 == 0 { + continue; + } + + for $subchunk in chunk.chunks_exact(MASK_BITS) { + let $m = m64 as $MaskT; + $body; + m64 >>= MASK_BITS % 64; + } + } + + // Handle the SIMD-block-sized remainder. + let subchunks = $values.chunks_exact(MASK_BITS); + $values = subchunks.remainder(); + for $subchunk in subchunks { + let mask_chunk = $mask_bytes.get_unchecked(..MASK_BITS / 8); + $mask_bytes = $mask_bytes.get_unchecked(MASK_BITS / 8..); + let $m = <$MaskT>::from_le_bytes(mask_chunk.try_into().unwrap()); + $body; + } + + ($values, $mask_bytes, $out) + }}; +} + +/// # Safety +/// out must be valid for 64 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512_VBMI2 must be enabled. +#[target_feature(enable = "avx512f")] +#[target_feature(enable = "avx512vbmi2")] +pub unsafe fn filter_u8_avx512vbmi2<'a>( + mut values: &'a [u8], + mut mask_bytes: &'a [u8], + mut out: *mut u8, +) -> (&'a [u8], &'a [u8], *mut u8) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u64| { + // We don't use compress-store instructions because they are very slow + // on Zen. We are allowed to overshoot anyway. + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi8(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} + +/// # Safety +/// out must be valid for 32 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512_VBMI2 must be enabled. +#[target_feature(enable = "avx512f")] +#[target_feature(enable = "avx512vbmi2")] +pub unsafe fn filter_u16_avx512vbmi2<'a>( + mut values: &'a [u16], + mut mask_bytes: &'a [u8], + mut out: *mut u16, +) -> (&'a [u16], &'a [u8], *mut u16) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u32| { + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi16(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} + +/// # Safety +/// out must be valid for 16 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512F must be enabled. +#[target_feature(enable = "avx512f")] +pub unsafe fn filter_u32_avx512f<'a>( + mut values: &'a [u32], + mut mask_bytes: &'a [u8], + mut out: *mut u32, +) -> (&'a [u32], &'a [u8], *mut u32) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u16| { + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi32(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} + +/// # Safety +/// out must be valid for 8 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +/// AVX512F must be enabled. +#[target_feature(enable = "avx512f")] +pub unsafe fn filter_u64_avx512f<'a>( + mut values: &'a [u64], + mut mask_bytes: &'a [u8], + mut out: *mut u64, +) -> (&'a [u64], &'a [u8], *mut u64) { + simd_filter!(values, mask_bytes, out, |vchunk, m: u8| { + let v = _mm512_loadu_si512(vchunk.as_ptr().cast()); + let filtered = _mm512_maskz_compress_epi64(m, v); + _mm512_storeu_si512(out.cast(), filtered); + out = out.add(m.count_ones() as usize); + }) +} diff --git a/crates/polars-compute/src/filter/boolean.rs b/crates/polars-compute/src/filter/boolean.rs index 4030d819a59b6..7ff426416c6d9 100644 --- a/crates/polars-compute/src/filter/boolean.rs +++ b/crates/polars-compute/src/filter/boolean.rs @@ -1,160 +1,321 @@ -use super::*; +use arrow::bitmap::Bitmap; +use polars_utils::clmul::prefix_xorsum; +use polars_utils::slice::load_padded_le_u64; -pub(super) fn filter_bitmap_and_validity( - values: &Bitmap, - validity: Option<&Bitmap>, - mask: &Bitmap, -) -> (MutableBitmap, Option) { - if let Some(validity) = validity { - let (values, validity) = null_filter(values, validity, mask); - (values, Some(validity)) - } else { - (nonnull_filter(values, mask), None) +const U56_MAX: u64 = (1 << 56) - 1; + +fn pext64_polyfill(mut v: u64, mut m: u64, m_popcnt: u32) -> u64 { + // Fast path: popcount is low. + if m_popcnt <= 4 { + // Not a "while m != 0" but a for loop instead so the compiler fully + // unrolls the loop, this makes bit << i much faster. + let mut out = 0; + for i in 0..4 { + if m == 0 { + break; + }; + + let bit = (v >> m.trailing_zeros()) & 1; + out |= bit << i; + m &= m.wrapping_sub(1); // Clear least significant bit. + } + return out; + } + + // Fast path: all the masked bits in v are 0 or 1. + // Despite this fast path being simpler than the above popcount-based one, + // we do it afterwards because if m has a low popcount these branches become + // very unpredictable. + v &= m; + if v == 0 { + return 0; + } else if v == m { + return (1 << m_popcnt) - 1; + } + + // This algorithm is too involved to explain here, see https://github.com/zwegner/zp7. + // That is an optimized version of Hacker's Delight Chapter 7-4, parallel suffix method for compress(). + let mut invm = !m; + + for i in 0..6 { + let shift = 1 << i; + let prefix_count_bit = if i < 5 { + prefix_xorsum(invm) + } else { + invm.wrapping_neg() << 1 + }; + let keep_in_place = v & !prefix_count_bit; + let shift_down = v & prefix_count_bit; + v = keep_in_place | (shift_down >> shift); + invm &= prefix_count_bit; + } + v +} + +pub fn filter_boolean_kernel(values: &Bitmap, mask: &Bitmap) -> Bitmap { + assert_eq!(values.len(), mask.len()); + let mask_bits_set = mask.set_bits(); + + // Fast path: values is all-0s or all-1s. + if let Some(num_values_bits) = values.lazy_set_bits() { + if num_values_bits == 0 || num_values_bits == values.len() { + return Bitmap::new_with_value(num_values_bits == values.len(), mask_bits_set); + } } + + // Fast path: mask is all-0s or all-1s. + if mask_bits_set == 0 { + return Bitmap::new(); + } else if mask_bits_set == mask.len() { + return values.clone(); + } + + // Overallocate by 1 u64 so we can always do a full u64 write. + let num_words = mask_bits_set.div_ceil(64); + let num_bytes = 8 * (num_words + 1); + let mut out_vec: Vec = Vec::with_capacity(num_bytes); + + unsafe { + if mask_bits_set <= mask.len() / (64 * 4) { + // Less than one in 1 in 4 words has a bit set on average, use sparse kernel. + filter_boolean_kernel_sparse(values, mask, out_vec.as_mut_ptr()); + } else if polars_utils::cpuid::has_fast_bmi2() { + #[cfg(target_arch = "x86_64")] + filter_boolean_kernel_pext::(values, mask, out_vec.as_mut_ptr(), |v, m, _| { + // SAFETY: has_fast_bmi2 ensures this is a legal instruction. + core::arch::x86_64::_pext_u64(v, m) + }); + } else { + filter_boolean_kernel_pext::( + values, + mask, + out_vec.as_mut_ptr(), + pext64_polyfill, + ) + } + + // SAFETY: the above filters must have initialized these bytes. + out_vec.set_len(mask_bits_set.div_ceil(8)); + } + + Bitmap::from_u8_vec(out_vec, mask_bits_set) } /// # Safety -/// This assumes that the `mask_chunks` contains a number of set/true items equal -/// to `filter_count` -unsafe fn nonnull_filter_impl( - values: &Bitmap, - mut mask_chunks: I, - filter_count: usize, -) -> MutableBitmap -where - I: BitChunkIterExact, -{ - // TODO! we might use ChunksExact here if offset = 0. - let mut chunks = values.chunks::(); - let mut new = MutableBitmap::with_capacity(filter_count); - - chunks - .by_ref() - .zip(mask_chunks.by_ref()) - .for_each(|(chunk, mask_chunk)| { - let ones = mask_chunk.count_ones(); - let leading_ones = get_leading_ones(mask_chunk); - - if ones == leading_ones { - let size = leading_ones as usize; - unsafe { new.extend_from_slice_unchecked(chunk.to_ne_bytes().as_ref(), 0, size) }; - return; - } +/// out_ptr must point to a buffer of length >= 8 + 8 * ceil(mask.set_bits() / 64). +/// This function will initialize at least the first ceil(mask.set_bits() / 8) bytes. +unsafe fn filter_boolean_kernel_sparse(values: &Bitmap, mask: &Bitmap, mut out_ptr: *mut u8) { + assert_eq!(values.len(), mask.len()); + + let mut value_idx = 0; + let mut bits_in_word = 0usize; + let mut word = 0u64; + let (mut mask_bytes, offset, len) = mask.as_slice(); + if len == 0 { + return; + } + + // Handle offset. + if offset > 0 { + let first_byte = mask_bytes[0]; + mask_bytes = &mask_bytes[1..]; - let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); - for pos in ones_iter { - new.push_unchecked(chunk & (1 << pos) > 0); + for byte_idx in offset..8 { + let mask_bit = first_byte & (1 << byte_idx) != 0; + if mask_bit && value_idx < len { + let bit = unsafe { values.get_bit_unchecked(value_idx) }; + word |= (bit as u64) << bits_in_word; + bits_in_word += 1; } - }); + value_idx += 1; + } + } - chunks - .remainder_iter() - .zip(mask_chunks.remainder_iter()) - .for_each(|(value, is_selected)| { - if is_selected { - unsafe { - new.push_unchecked(value); - }; + macro_rules! loop_body { + ($m: expr) => {{ + let mut m = $m; + while m > 0 { + let idx_in_m = m.trailing_zeros() as usize; + let bit = unsafe { values.get_bit_unchecked(value_idx + idx_in_m) }; + word |= (bit as u64) << bits_in_word; + bits_in_word += 1; + + if bits_in_word == 64 { + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + out_ptr = out_ptr.add(8); + bits_in_word = 0; + word = 0; + } + } + + m &= m.wrapping_sub(1); // Clear least significant bit. } - }); + }}; + } - new + // Handle bulk. + while value_idx + 64 <= len { + let chunk; + unsafe { + // SAFETY: we checked that value and mask have same length. + chunk = mask_bytes.get_unchecked(0..8); + mask_bytes = mask_bytes.get_unchecked(8..); + }; + let m = u64::from_le_bytes(chunk.try_into().unwrap()); + loop_body!(m); + value_idx += 64; + } + + // Handle remainder. + if value_idx < len { + let rest_len = len - value_idx; + assert!(rest_len < 64); + let m = load_padded_le_u64(mask_bytes) & ((1 << rest_len) - 1); + loop_body!(m); + } + + if bits_in_word > 0 { + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + } + } } /// # Safety -/// This assumes that the `mask_chunks` contains a number of set/true items equal -/// to `filter_count` -unsafe fn null_filter_impl( +/// See filter_boolean_kernel_sparse. +unsafe fn filter_boolean_kernel_pext u64>( values: &Bitmap, - validity: &Bitmap, - mut mask_chunks: I, - filter_count: usize, -) -> (MutableBitmap, MutableBitmap) -where - I: BitChunkIterExact, -{ - let mut chunks = values.chunks::(); - let mut validity_chunks = validity.chunks::(); - - let mut new = MutableBitmap::with_capacity(filter_count); - let mut new_validity = MutableBitmap::with_capacity(filter_count); - - chunks - .by_ref() - .zip(validity_chunks.by_ref()) - .zip(mask_chunks.by_ref()) - .for_each(|((chunk, validity_chunk), mask_chunk)| { - let ones = mask_chunk.count_ones(); - let leading_ones = get_leading_ones(mask_chunk); - - if ones == leading_ones { - let size = leading_ones as usize; + mask: &Bitmap, + mut out_ptr: *mut u8, + pext: F, +) { + assert_eq!(values.len(), mask.len()); + let mut bits_in_word = 0usize; + let mut word = 0u64; + + macro_rules! loop_body { + ($v: expr, $m: expr) => {{ + let (v, m) = ($v, $m); + + // Fast-path, all-0 mask. + if m == 0 { + continue; + } + + // Fast path, all-1 mask. + // This is only worth it if we don't have a native pext. + if !HAS_NATIVE_PEXT && m == U56_MAX { + word |= v << bits_in_word; unsafe { - new.extend_from_slice_unchecked(chunk.to_ne_bytes().as_ref(), 0, size); - - // SAFETY: invariant offset + length <= slice.len() - new_validity.extend_from_slice_unchecked( - validity_chunk.to_ne_bytes().as_ref(), - 0, - size, - ); + out_ptr.cast::().write_unaligned(word.to_le()); + out_ptr = out_ptr.add(7); } - return; + word >>= 56; + continue; } - // this triggers a bitcount - let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); - for pos in ones_iter { - new.push_unchecked(chunk & (1 << pos) > 0); - new_validity.push_unchecked(validity_chunk & (1 << pos) > 0); - } - }); - - chunks - .remainder_iter() - .zip(validity_chunks.remainder_iter()) - .zip(mask_chunks.remainder_iter()) - .for_each(|((value, is_valid), is_selected)| { - if is_selected { - unsafe { - new.push_unchecked(value); - new_validity.push_unchecked(is_valid); - }; + let mask_popcnt = m.count_ones(); + let bits = pext(v, m, mask_popcnt); + + // Because we keep bits_in_word < 8 and we iterate over u56s, + // this never loses output bits. + word |= bits << bits_in_word; + bits_in_word += mask_popcnt as usize; + unsafe { + out_ptr.cast::().write_unaligned(word.to_le()); + + let full_bytes_written = bits_in_word / 8; + out_ptr = out_ptr.add(full_bytes_written); + word >>= full_bytes_written * 8; + bits_in_word %= 8; } - }); + }}; + } - (new, new_validity) + let mut v_iter = values.fast_iter_u56(); + let mut m_iter = mask.fast_iter_u56(); + for v in &mut v_iter { + // SAFETY: we checked values and mask have same length. + let m = unsafe { m_iter.next().unwrap_unchecked() }; + loop_body!(v, m); + } + let mut v_rem = v_iter.remainder().0; + let mut m_rem = m_iter.remainder().0; + while m_rem != 0 { + let v = v_rem & U56_MAX; + let m = m_rem & U56_MAX; + v_rem >>= 56; + m_rem >>= 56; + loop_body!(v, m); // Careful, contains 'continue', increment loop variables first. + } } -fn null_filter( +pub fn filter_bitmap_and_validity( values: &Bitmap, - validity: &Bitmap, + validity: Option<&Bitmap>, mask: &Bitmap, -) -> (MutableBitmap, MutableBitmap) { - assert_eq!(values.len(), mask.len()); - let filter_count = mask.len() - mask.unset_bits(); - - let (slice, offset, length) = mask.as_slice(); - if offset == 0 { - let mask_chunks = BitChunksExact::::new(slice, length); - unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } +) -> (Bitmap, Option) { + let filtered_values = filter_boolean_kernel(values, mask); + if let Some(validity) = validity { + // TODO: we could theoretically be faster by computing these two filters + // at once. Unsure if worth duplicating all the code above. + let filtered_validity = filter_boolean_kernel(validity, mask); + (filtered_values, Some(filtered_validity)) } else { - let mask_chunks = mask.chunks::(); - unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } + (filtered_values, None) } } -fn nonnull_filter(values: &Bitmap, mask: &Bitmap) -> MutableBitmap { - assert_eq!(values.len(), mask.len()); - let filter_count = mask.len() - mask.unset_bits(); +#[cfg(test)] +mod test { + use rand::prelude::*; - let (slice, offset, length) = mask.as_slice(); - if offset == 0 { - let mask_chunks = BitChunksExact::::new(slice, length); - unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } - } else { - let mask_chunks = mask.chunks::(); - unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } + use super::*; + + fn naive_pext64(word: u64, mask: u64) -> u64 { + let mut out = 0; + let mut out_idx = 0; + + for i in 0..64 { + let ith_mask_bit = (mask >> i) & 1; + let ith_word_bit = (word >> i) & 1; + if ith_mask_bit == 1 { + out |= ith_word_bit << out_idx; + out_idx += 1; + } + } + + out + } + + #[test] + fn test_pext64() { + // Verify polyfill against naive implementation. + let mut rng = StdRng::seed_from_u64(0xdeadbeef); + for _ in 0..100 { + let x = rng.gen(); + let y = rng.gen(); + assert_eq!(naive_pext64(x, y), pext64_polyfill(x, y, y.count_ones())); + + // Test all-zeros and all-ones. + assert_eq!(naive_pext64(0, y), pext64_polyfill(0, y, y.count_ones())); + assert_eq!( + naive_pext64(u64::MAX, y), + pext64_polyfill(u64::MAX, y, y.count_ones()) + ); + assert_eq!(naive_pext64(x, 0), pext64_polyfill(x, 0, 0)); + assert_eq!(naive_pext64(x, u64::MAX), pext64_polyfill(x, u64::MAX, 64)); + + // Test low popcount mask. + let popcnt = rng.gen_range(0..=8); + // Not perfect (can generate same bit twice) but it'll do. + let mask = (0..popcnt).map(|_| 1 << rng.gen_range(0..64)).sum(); + assert_eq!( + naive_pext64(x, mask), + pext64_polyfill(x, mask, mask.count_ones()) + ); + } } } diff --git a/crates/polars-compute/src/filter/mod.rs b/crates/polars-compute/src/filter/mod.rs index ed6cdd12636ec..38cce5c103d29 100644 --- a/crates/polars-compute/src/filter/mod.rs +++ b/crates/polars-compute/src/filter/mod.rs @@ -1,34 +1,22 @@ //! Contains operators to filter arrays such as [`filter`]. mod boolean; mod primitive; +mod scalar; + +#[cfg(all(target_arch = "x86_64", feature = "simd"))] +mod avx512; use arrow::array::growable::make_growable; -use arrow::array::*; -use arrow::bitmap::utils::{BitChunkIterExact, BitChunksExact, SlicesIterator}; -use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::array::{new_empty_array, Array, BinaryViewArray, BooleanArray, PrimitiveArray}; +use arrow::bitmap::utils::SlicesIterator; use arrow::datatypes::ArrowDataType; -use arrow::types::simd::Simd; -use arrow::types::{BitChunkOnes, NativeType}; use arrow::with_match_primitive_type_full; -use boolean::*; -use polars_error::*; -use primitive::*; - -/// Function that can filter arbitrary arrays -pub type Filter<'a> = Box Box + 'a + Send + Sync>; - -#[inline] -fn get_leading_ones(chunk: u64) -> u32 { - if cfg!(target_endian = "little") { - chunk.trailing_ones() - } else { - chunk.leading_ones() - } -} +use polars_error::PolarsResult; pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult> { - // The validities may be masking out `true` bits, making the filter operation - // based on the values incorrect + assert_eq!(array.len(), mask.len()); + + // Treat null mask values as false. if let Some(validities) = mask.validity() { let values = mask.values(); let new_values = values & validities; @@ -36,46 +24,43 @@ pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult with_match_primitive_type_full!(primitive, |$T| { - let array = array.as_any().downcast_ref().unwrap(); - Ok(Box::new(filter_primitive::<$T>(array, mask.values()))) + let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap(); + let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask.values()); + Ok(Box::new(PrimitiveArray::from_vec(values).with_validity(validity))) }), Boolean => { let array = array.as_any().downcast_ref::().unwrap(); - let (values, validity) = - filter_bitmap_and_validity(array.values(), array.validity(), mask.values()); - Ok(BooleanArray::new( - array.data_type().clone(), - values.freeze(), - validity.map(|v| v.freeze()), - ) - .boxed()) + let (values, validity) = boolean::filter_bitmap_and_validity( + array.values(), + array.validity(), + mask.values(), + ); + Ok(BooleanArray::new(array.data_type().clone(), values, validity).boxed()) }, BinaryView => { let array = array.as_any().downcast_ref::().unwrap(); let views = array.views(); let validity = array.validity(); - // TODO! we might opt for a filter that maintains the bytes_count - // currently we don't do that and bytes_len is set to UNKNOWN. - let (views, validity) = filter_values_and_validity(views, validity, mask.values()); + let (views, validity) = + primitive::filter_values_and_validity(views, validity, mask.values()); Ok(unsafe { BinaryViewArray::new_unchecked_unknown_md( array.data_type().clone(), views.into(), array.data_buffers().clone(), - validity.map(|v| v.freeze()), + validity, Some(array.total_buffer_len()), ) } diff --git a/crates/polars-compute/src/filter/primitive.rs b/crates/polars-compute/src/filter/primitive.rs index 085074efbe715..10c00afdff1cd 100644 --- a/crates/polars-compute/src/filter/primitive.rs +++ b/crates/polars-compute/src/filter/primitive.rs @@ -1,183 +1,92 @@ -use super::*; - -pub(super) fn filter_values_and_validity( - values: &[T], - validity: Option<&Bitmap>, - mask: &Bitmap, -) -> (Vec, Option) { - if let Some(validity) = validity { - let (values, validity) = null_filter(values, validity, mask); - (values, Some(validity)) - } else { - (nonnull_filter(values, mask), None) - } +use arrow::bitmap::Bitmap; +use bytemuck::{cast_slice, cast_vec, Pod}; + +#[cfg(all(target_arch = "x86_64", feature = "simd"))] +use super::avx512; +use super::boolean::filter_boolean_kernel; +use super::scalar::{scalar_filter, scalar_filter_offset}; + +type FilterFn = for<'a> unsafe fn(&'a [T], &'a [u8], *mut T) -> (&'a [T], &'a [u8], *mut T); + +fn nop_filter<'a, T: Pod>( + values: &'a [T], + mask: &'a [u8], + out: *mut T, +) -> (&'a [T], &'a [u8], *mut T) { + (values, mask, out) } -pub(super) fn filter_primitive( - array: &PrimitiveArray, - mask: &Bitmap, -) -> PrimitiveArray { - assert_eq!(array.len(), mask.len()); - let (values, validity) = filter_values_and_validity(array.values(), array.validity(), mask); - let validity = validity.map(|validity| validity.freeze()); - unsafe { - PrimitiveArray::::new_unchecked(array.data_type().clone(), values.into(), validity) +pub fn filter_values(values: &[T], mask: &Bitmap) -> Vec { + match (std::mem::size_of::(), std::mem::align_of::()) { + (1, 1) => cast_vec(filter_values_u8(cast_slice(values), mask)), + (2, 2) => cast_vec(filter_values_u16(cast_slice(values), mask)), + (4, 4) => cast_vec(filter_values_u32(cast_slice(values), mask)), + (8, 8) => cast_vec(filter_values_u64(cast_slice(values), mask)), + _ => filter_values_generic(values, mask, 1, nop_filter), } } -/// # Safety -/// This assumes that the `mask_chunks` contains a number of set/true items equal -/// to `filter_count` -unsafe fn nonnull_filter_impl(values: &[T], mut mask_chunks: I, filter_count: usize) -> Vec -where - T: NativeType, - I: BitChunkIterExact, -{ - let mut chunks = values.chunks_exact(64); - let mut new = Vec::::with_capacity(filter_count); - let mut dst = new.as_mut_ptr(); - - chunks - .by_ref() - .zip(mask_chunks.by_ref()) - .for_each(|(chunk, mask_chunk)| { - let ones = mask_chunk.count_ones(); - let leading_ones = get_leading_ones(mask_chunk); - - if ones == leading_ones { - let size = leading_ones as usize; - unsafe { - std::ptr::copy(chunk.as_ptr(), dst, size); - dst = dst.add(size); - } - return; - } - - let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); - for pos in ones_iter { - dst.write(*chunk.get_unchecked(pos)); - dst = dst.add(1); - } - }); - - chunks - .remainder() - .iter() - .zip(mask_chunks.remainder_iter()) - .for_each(|(value, b)| { - if b { - unsafe { - dst.write(*value); - dst = dst.add(1); - }; - } - }); +fn filter_values_u8(values: &[u8], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if std::arch::is_x86_feature_detected!("avx512vbmi2") { + return filter_values_generic(values, mask, 64, avx512::filter_u8_avx512vbmi2); + } - unsafe { new.set_len(filter_count) }; - new + filter_values_generic(values, mask, 1, nop_filter) } -/// # Safety -/// This assumes that the `mask_chunks` contains a number of set/true items equal -/// to `filter_count` -unsafe fn null_filter_impl( - values: &[T], - validity: &Bitmap, - mut mask_chunks: I, - filter_count: usize, -) -> (Vec, MutableBitmap) -where - T: NativeType, - I: BitChunkIterExact, -{ - let mut chunks = values.chunks_exact(64); - - let mut validity_chunks = validity.chunks::(); - - let mut new = Vec::::with_capacity(filter_count); - let mut dst = new.as_mut_ptr(); - let mut new_validity = MutableBitmap::with_capacity(filter_count); - - chunks - .by_ref() - .zip(validity_chunks.by_ref()) - .zip(mask_chunks.by_ref()) - .for_each(|((chunk, validity_chunk), mask_chunk)| { - let ones = mask_chunk.count_ones(); - let leading_ones = get_leading_ones(mask_chunk); +fn filter_values_u16(values: &[u16], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if std::arch::is_x86_feature_detected!("avx512vbmi2") { + return filter_values_generic(values, mask, 32, avx512::filter_u16_avx512vbmi2); + } - if ones == leading_ones { - let size = leading_ones as usize; - unsafe { - std::ptr::copy(chunk.as_ptr(), dst, size); - dst = dst.add(size); + filter_values_generic(values, mask, 1, nop_filter) +} - // SAFETY: invariant offset + length <= slice.len() - new_validity.extend_from_slice_unchecked( - validity_chunk.to_ne_bytes().as_ref(), - 0, - size, - ); - } - return; - } +fn filter_values_u32(values: &[u32], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if std::arch::is_x86_feature_detected!("avx512f") { + return filter_values_generic(values, mask, 16, avx512::filter_u32_avx512f); + } - // this triggers a bitcount - let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); - for pos in ones_iter { - dst.write(*chunk.get_unchecked(pos)); - dst = dst.add(1); - new_validity.push_unchecked(validity_chunk & (1 << pos) > 0); - } - }); + filter_values_generic(values, mask, 1, nop_filter) +} - chunks - .remainder() - .iter() - .zip(validity_chunks.remainder_iter()) - .zip(mask_chunks.remainder_iter()) - .for_each(|((value, is_valid), is_selected)| { - if is_selected { - unsafe { - dst.write(*value); - dst = dst.add(1); - new_validity.push_unchecked(is_valid); - }; - } - }); +fn filter_values_u64(values: &[u64], mask: &Bitmap) -> Vec { + #[cfg(all(target_arch = "x86_64", feature = "simd"))] + if std::arch::is_x86_feature_detected!("avx512f") { + return filter_values_generic(values, mask, 8, avx512::filter_u64_avx512f); + } - unsafe { new.set_len(filter_count) }; - (new, new_validity) + filter_values_generic(values, mask, 1, nop_filter) } -fn null_filter( +fn filter_values_generic( values: &[T], - validity: &Bitmap, mask: &Bitmap, -) -> (Vec, MutableBitmap) { + pad: usize, + bulk_filter: FilterFn, +) -> Vec { assert_eq!(values.len(), mask.len()); - let filter_count = mask.len() - mask.unset_bits(); - - let (slice, offset, length) = mask.as_slice(); - if offset == 0 { - let mask_chunks = BitChunksExact::::new(slice, length); - unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } - } else { - let mask_chunks = mask.chunks::(); - unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } + let mask_bits_set = mask.set_bits(); + let mut out = Vec::with_capacity(mask_bits_set + pad); + unsafe { + let (values, mask_bytes, out_ptr) = scalar_filter_offset(values, mask, out.as_mut_ptr()); + let (values, mask_bytes, out_ptr) = bulk_filter(values, mask_bytes, out_ptr); + scalar_filter(values, mask_bytes, out_ptr); + out.set_len(mask_bits_set); } + out } -fn nonnull_filter(values: &[T], mask: &Bitmap) -> Vec { - assert_eq!(values.len(), mask.len()); - let filter_count = mask.len() - mask.unset_bits(); - - let (slice, offset, length) = mask.as_slice(); - if offset == 0 { - let mask_chunks = BitChunksExact::::new(slice, length); - unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } - } else { - let mask_chunks = mask.chunks::(); - unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } - } +pub fn filter_values_and_validity( + values: &[T], + validity: Option<&Bitmap>, + mask: &Bitmap, +) -> (Vec, Option) { + ( + filter_values(values, mask), + validity.map(|v| filter_boolean_kernel(v, mask)), + ) } diff --git a/crates/polars-compute/src/filter/scalar.rs b/crates/polars-compute/src/filter/scalar.rs new file mode 100644 index 0000000000000..e9576bad8a2cb --- /dev/null +++ b/crates/polars-compute/src/filter/scalar.rs @@ -0,0 +1,137 @@ +use arrow::bitmap::Bitmap; +use bytemuck::Pod; +use polars_utils::slice::load_padded_le_u64; + +/// # Safety +/// If the ith bit of m is set (m & (1 << i)), then v[i] must be in-bounds. +/// out must be valid for at least m.count_ones() + 1 writes. +unsafe fn scalar_sparse_filter64(v: &[T], mut m: u64, out: *mut T) { + let mut written = 0usize; + + while m > 0 { + // Unroll loop manually twice. + let idx = m.trailing_zeros() as usize; + *out.add(written) = *v.get_unchecked(idx); + m &= m.wrapping_sub(1); // Clear least significant bit. + written += 1; + + // tz % 64 otherwise we could go out of bounds + let idx = (m.trailing_zeros() % 64) as usize; + *out.add(written) = *v.get_unchecked(idx); + m &= m.wrapping_sub(1); // Clear least significant bit. + written += 1; + } +} + +/// # Safety +/// v.len() >= 64 must hold. +/// out must be valid for at least m.count_ones() + 1 writes. +unsafe fn scalar_dense_filter64(v: &[T], mut m: u64, out: *mut T) { + // Rust generated significantly better code if we write the below loop + // with v as a pointer, and out.add(written) instead of incrementing out + // directly. + let mut written = 0usize; + let mut src = v.as_ptr(); + + // We hope the outer loop doesn't get unrolled, but the inner loop does. + for _ in 0..16 { + for i in 0..4 { + *out.add(written) = *src; + written += ((m >> i) & 1) as usize; + src = src.add(1); + } + m >>= 4; + } +} + +/// Handles the offset portion of a Bitmap to start an efficient filter operation. +/// Returns the remaining values and mask bytes for the filter, as well as where +/// to continue writing to out. +/// +/// # Safety +/// out must be valid for at least mask.set_bits() + 1 writes. +pub unsafe fn scalar_filter_offset<'a, T: Pod>( + values: &'a [T], + mask: &'a Bitmap, + mut out: *mut T, +) -> (&'a [T], &'a [u8], *mut T) { + assert_eq!(values.len(), mask.len()); + + let (mut mask_bytes, offset, len) = mask.as_slice(); + let mut value_idx = 0; + if offset > 0 { + let first_byte = mask_bytes[0]; + mask_bytes = &mask_bytes[1..]; + + for byte_idx in offset..8 { + if value_idx < len { + unsafe { + // SAFETY: we checked that value_idx < len. + let bit_is_set = first_byte & (1 << byte_idx) != 0; + *out = *values.get_unchecked(value_idx); + out = out.add(bit_is_set as usize); + } + value_idx += 1; + } + } + } + + (&values[value_idx..], mask_bytes, out) +} + +/// # Safety +/// out must be valid for 1 + bitslice(mask_bytes, 0..values.len()).count_ones() writes. +pub unsafe fn scalar_filter(values: &[T], mut mask_bytes: &[u8], mut out: *mut T) { + assert!(mask_bytes.len() * 8 >= values.len()); + + // Handle bulk. + let mut value_idx = 0; + while value_idx + 64 <= values.len() { + let (mask_chunk, value_chunk); + unsafe { + // SAFETY: we checked that value_idx + 64 <= values.len(), so these are + // all in-bounds. + mask_chunk = mask_bytes.get_unchecked(0..8); + mask_bytes = mask_bytes.get_unchecked(8..); + value_chunk = values.get_unchecked(value_idx..value_idx + 64); + value_idx += 64; + }; + let m = u64::from_le_bytes(mask_chunk.try_into().unwrap()); + + // Fast-path: empty mask. + if m == 0 { + continue; + } + + unsafe { + // SAFETY: we will only write at most m_popcnt + 1 to out, which + // is allowed. + + // Fast-path: completely full mask. + if m == u64::MAX { + core::ptr::copy_nonoverlapping(value_chunk.as_ptr(), out, 64); + out = out.add(64); + continue; + } + + let m_popcnt = m.count_ones(); + if m_popcnt <= 16 { + scalar_sparse_filter64(value_chunk, m, out) + } else { + scalar_dense_filter64(value_chunk, m, out) + }; + out = out.add(m_popcnt as usize); + } + } + + // Handle remainder. + if value_idx < values.len() { + let rest_len = values.len() - value_idx; + assert!(rest_len < 64); + let m = load_padded_le_u64(mask_bytes) & ((1 << rest_len) - 1); + unsafe { + let value_chunk = values.get_unchecked(value_idx..); + scalar_sparse_filter64(value_chunk, m, out); + } + } +} diff --git a/crates/polars-compute/src/lib.rs b/crates/polars-compute/src/lib.rs index 0cd894d38013c..797e8a8af9b04 100644 --- a/crates/polars-compute/src/lib.rs +++ b/crates/polars-compute/src/lib.rs @@ -1,4 +1,9 @@ #![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "simd", feature(avx512_target_feature))] +#![cfg_attr( + all(feature = "simd", target_arch = "x86_64"), + feature(stdarch_x86_avx512) +)] pub mod arithmetic; pub mod comparisons; diff --git a/crates/polars-compute/src/min_max/simd.rs b/crates/polars-compute/src/min_max/simd.rs index 887a0a783624d..e72df453c54d4 100644 --- a/crates/polars-compute/src/min_max/simd.rs +++ b/crates/polars-compute/src/min_max/simd.rs @@ -135,22 +135,16 @@ macro_rules! impl_min_max_kernel_float { type Scalar<'a> = $T; fn min_ignore_nan_kernel(&self) -> Option> { - fold_agg_kernel::<$N, $T, _>( - self.values(), - self.validity(), - <$T>::INFINITY, - |a, b| a.simd_min(b), - ) + fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| { + a.simd_min(b) + }) .map(|s| s.reduce_min()) } fn max_ignore_nan_kernel(&self) -> Option> { - fold_agg_kernel::<$N, $T, _>( - self.values(), - self.validity(), - <$T>::NEG_INFINITY, - |a, b| a.simd_max(b), - ) + fold_agg_kernel::<$N, $T, _>(self.values(), self.validity(), <$T>::NAN, |a, b| { + a.simd_max(b) + }) .map(|s| s.reduce_max()) } @@ -179,12 +173,12 @@ macro_rules! impl_min_max_kernel_float { type Scalar<'a> = $T; fn min_ignore_nan_kernel(&self) -> Option> { - fold_agg_kernel::<$N, $T, _>(self, None, <$T>::INFINITY, |a, b| a.simd_min(b)) + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_min(b)) .map(|s| s.reduce_min()) } fn max_ignore_nan_kernel(&self) -> Option> { - fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NEG_INFINITY, |a, b| a.simd_max(b)) + fold_agg_kernel::<$N, $T, _>(self, None, <$T>::NAN, |a, b| a.simd_max(b)) .map(|s| s.reduce_max()) } diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 1915eb71957b3..297caa7a71f3e 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -80,7 +80,6 @@ round_series = [] checked_arithmetic = [] is_first_distinct = [] is_last_distinct = [] -asof_join = [] dot_product = [] row_hash = [] reinterpret = [] @@ -133,7 +132,6 @@ docs-selection = [ "checked_arithmetic", "is_first_distinct", "is_last_distinct", - "asof_join", "dot_product", "row_hash", "rolling_window", diff --git a/crates/polars-core/src/chunked_array/arithmetic/decimal.rs b/crates/polars-core/src/chunked_array/arithmetic/decimal.rs index 89efa856db6e9..3e50676566216 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/decimal.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/decimal.rs @@ -1,5 +1,4 @@ use super::*; -use crate::prelude::DecimalChunked; impl Add for &DecimalChunked { type Output = PolarsResult; diff --git a/crates/polars-core/src/chunked_array/arithmetic/mod.rs b/crates/polars-core/src/chunked_array/arithmetic/mod.rs index 0a04f7ae624e6..874ed20330976 100644 --- a/crates/polars-core/src/chunked_array/arithmetic/mod.rs +++ b/crates/polars-core/src/chunked_array/arithmetic/mod.rs @@ -5,7 +5,6 @@ mod numeric; use std::ops::{Add, Div, Mul, Rem, Sub}; -use arrow::array::PrimitiveArray; use arrow::compute::utils::combine_validities_and; use num_traits::{Num, NumCast, ToPrimitive}; pub use numeric::ArithmeticChunked; diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index fcb57e0154179..a9ecbf43ffb8e 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -103,6 +103,7 @@ impl ArrayChunked { } /// Apply a closure `F` to each array. + /// /// # Safety /// Return series of `F` must has the same dtype and number of elements as input. #[must_use] @@ -124,6 +125,7 @@ impl ArrayChunked { } /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + /// /// # Safety // Return series of `F` must has the same dtype and number of elements as input series. #[must_use] diff --git a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs index d2662121c98df..a419ee930401b 100644 --- a/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs +++ b/crates/polars-core/src/chunked_array/builder/fixed_size_list.rs @@ -1,7 +1,3 @@ -use arrow::array::{ - Array, MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, PrimitiveArray, - PushUnchecked, -}; use arrow::types::NativeType; use polars_utils::unwrap::UnwrapUncheckedRelease; use smartstring::alias::String as SmartString; @@ -16,7 +12,8 @@ pub(crate) struct FixedSizeListNumericBuilder { } impl FixedSizeListNumericBuilder { - /// SAFETY + /// # Safety + /// /// The caller must ensure that the physical numerical type match logical type. pub(crate) unsafe fn new( name: &str, diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index e31fa2968b7c0..fec9b92d86817 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -6,7 +6,6 @@ mod null; mod primitive; mod string; -use std::iter::FromIterator; use std::marker::PhantomData; use std::sync::Arc; diff --git a/crates/polars-core/src/chunked_array/cast.rs b/crates/polars-core/src/chunked_array/cast.rs index 29cfeed5012de..976dd568ddc41 100644 --- a/crates/polars-core/src/chunked_array/cast.rs +++ b/crates/polars-core/src/chunked_array/cast.rs @@ -1,10 +1,7 @@ //! Implementations of the ChunkCast Trait. -use std::convert::TryFrom; use arrow::compute::cast::CastOptions; -#[cfg(feature = "dtype-categorical")] -use crate::chunked_array::categorical::CategoricalChunkedBuilder; #[cfg(feature = "timezones")] use crate::chunked_array::temporal::validate_time_zone; #[cfg(feature = "dtype-datetime")] @@ -382,11 +379,12 @@ impl ChunkCast for ListChunked { match data_type { List(child_type) => { match (self.inner_dtype(), &**child_type) { + (old, new) if old == *new => Ok(self.clone().into_series()), #[cfg(feature = "dtype-categorical")] (dt, Categorical(None, _) | Enum(_, _)) if !matches!(dt, Categorical(_, _) | Enum(_, _) | String | Null) => { - polars_bail!(ComputeError: "cannot cast List inner type: '{:?}' to Categorical", dt) + polars_bail!(InvalidOperation: "cannot cast List inner type: '{:?}' to Categorical", dt) }, _ => { // ensure the inner logical type bubbles up @@ -420,7 +418,7 @@ impl ChunkCast for ListChunked { }, _ => { polars_bail!( - ComputeError: "cannot cast List type (inner: '{:?}', to: '{:?}')", + InvalidOperation: "cannot cast List type (inner: '{:?}', to: '{:?}')", self.inner_dtype(), data_type, ) @@ -445,10 +443,16 @@ impl ChunkCast for ArrayChunked { use DataType::*; match data_type { Array(child_type, width) => { + polars_ensure!( + *width == self.width(), + InvalidOperation: "cannot cast Array to a different width" + ); + match (self.inner_dtype(), &**child_type) { + (old, new) if old == *new => Ok(self.clone().into_series()), #[cfg(feature = "dtype-categorical")] (dt, Categorical(None, _) | Enum(_, _)) if !matches!(dt, String) => { - polars_bail!(ComputeError: "cannot cast fixed-size-list inner type: '{:?}' to dtype: {:?}", dt, child_type) + polars_bail!(InvalidOperation: "cannot cast Array inner type: '{:?}' to dtype: {:?}", dt, child_type) }, _ => { // ensure the inner logical type bubbles up @@ -479,7 +483,13 @@ impl ChunkCast for ArrayChunked { )) } }, - _ => polars_bail!(ComputeError: "cannot cast list type"), + _ => { + polars_bail!( + InvalidOperation: "cannot cast Array type (inner: '{:?}', to: '{:?}')", + self.inner_dtype(), + data_type, + ) + }, } } diff --git a/crates/polars-core/src/chunked_array/comparison/categorical.rs b/crates/polars-core/src/chunked_array/comparison/categorical.rs index 6036d1772b4bc..c288929b07cd9 100644 --- a/crates/polars-core/src/chunked_array/comparison/categorical.rs +++ b/crates/polars-core/src/chunked_array/comparison/categorical.rs @@ -168,7 +168,7 @@ where CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, { if lhs.is_enum() { - let rhs_cat = rhs.cast(lhs.dtype())?; + let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?; cat_compare_function(lhs, rhs_cat.categorical().unwrap()) } else if rhs.len() == 1 { match rhs.get(0) { @@ -198,7 +198,7 @@ where CompareString: Fn(&StringChunked, &'a StringChunked) -> BooleanChunked, { if lhs.is_enum() { - let rhs_cat = rhs.cast(lhs.dtype())?; + let rhs_cat = rhs.clone().into_series().strict_cast(lhs.dtype())?; cat_compare_function(lhs, rhs_cat.categorical().unwrap()) } else if rhs.len() == 1 { match rhs.get(0) { diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index d64111b48a1f2..57fdfc8c05912 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -8,7 +8,6 @@ use std::ops::Not; use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; use arrow::compute; -use arrow::legacy::prelude::FromData; use num_traits::{NumCast, ToPrimitive}; use polars_compute::comparisons::TotalOrdKernel; diff --git a/crates/polars-core/src/chunked_array/iterator/mod.rs b/crates/polars-core/src/chunked_array/iterator/mod.rs index d5e26f80241a6..27781b8928767 100644 --- a/crates/polars-core/src/chunked_array/iterator/mod.rs +++ b/crates/polars-core/src/chunked_array/iterator/mod.rs @@ -3,7 +3,6 @@ use arrow::array::*; use crate::prelude::*; #[cfg(feature = "dtype-struct")] use crate::series::iterator::SeriesIter; -use crate::utils::CustomIterTools; pub mod par; diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index f55812b4c47f4..890e156fc3f0c 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -4,7 +4,6 @@ use std::ptr::NonNull; use crate::prelude::*; use crate::series::unstable::{ArrayBox, UnstableSeries}; -use crate::utils::CustomIterTools; pub struct AmortizedListIter<'a, I: Iterator>> { len: usize, diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index 394df87caa2d8..eba65e8980ab4 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -30,6 +30,7 @@ impl ListChunked { } /// Set the logical type of the [`ListChunked`]. + /// /// # Safety /// The caller must ensure that the logical type given fits the physical type of the array. pub unsafe fn to_logical(&mut self, inner_dtype: DataType) { diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index 6e1db68ae178a..70c4167f80c6f 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -3,7 +3,6 @@ use arrow::legacy::trusted_len::TrustedLenPush; use hashbrown::hash_map::Entry; use polars_utils::iter::EnumerateIdxTrait; -use crate::datatypes::PlHashMap; use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::{using_string_cache, StringCache, POOL}; @@ -359,18 +358,12 @@ impl CategoricalChunked { map.insert(cat, idx as u32); } // Find idx of every value in the map - let mut keys: UInt32Chunked = values - .into_iter() - .map(|opt_s: Option<&str>| { - opt_s - .map(|s| { - map.get(s).copied().ok_or_else(|| { - polars_err!(not_in_enum, value = s, categories = categories) - }) - }) - .transpose() - }) - .collect::>()?; + let iter = values.downcast_iter().map(|arr| { + arr.iter() + .map(|opt_s: Option<&str>| opt_s.and_then(|s| map.get(s).copied())) + .collect_arr() + }); + let mut keys: UInt32Chunked = ChunkedArray::from_chunk_iter(values.name(), iter); keys.rename(values.name()); let rev_map = RevMapping::build_local(categories.clone()); unsafe { @@ -386,7 +379,6 @@ impl CategoricalChunked { #[cfg(test)] mod test { - use crate::chunked_array::categorical::CategoricalChunkedBuilder; use crate::prelude::*; use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/from.rs b/crates/polars-core/src/chunked_array/logical/categorical/from.rs index c34c2eccc183e..4be936d79555a 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/from.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/from.rs @@ -1,4 +1,3 @@ -use arrow::array::DictionaryArray; use arrow::compute::cast::{cast, utf8view_to_utf8, CastOptions}; use arrow::datatypes::IntegerType; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs index 21ffcc0e4d5cd..375f8cc3e72f4 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::sync::Arc; use super::*; use crate::series::IsSorted; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 715ff619e5fba..959c1f5ec6668 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -105,7 +105,7 @@ impl CategoricalChunked { self.get_ordering(), ) }; - out.set_fast_unique(self.can_fast_unique()); + out.set_fast_unique(self._can_fast_unique()); out } @@ -130,18 +130,18 @@ impl CategoricalChunked { } } - // Convert to fixed enum. In case a value is not in the categories return Error - pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> PolarsResult { + // Convert to fixed enum. Values not in categories are mapped to None. + pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self { // Fast paths match self.get_rev_map().as_ref() { RevMapping::Local(_, cur_hash) if hash == *cur_hash => { return unsafe { - Ok(CategoricalChunked::from_cats_and_rev_map_unchecked( + CategoricalChunked::from_cats_and_rev_map_unchecked( self.physical().clone(), self.get_rev_map().clone(), true, self.get_ordering(), - )) + ) }; }, _ => (), @@ -159,34 +159,18 @@ impl CategoricalChunked { let new_phys: UInt32Chunked = self .physical() .into_iter() - .map(|opt_v: Option| { - let Some(v) = opt_v else { - return Ok(None); - }; - - let Some(idx) = idx_map.get(&v) else { - polars_bail!( - not_in_enum, - value = old_rev_map.get(v), - categories = &categories - ); - }; + .map(|opt_v: Option| opt_v.and_then(|v| idx_map.get(&v).copied())) + .collect(); - Ok(Some(*idx)) - }) - .collect::>()?; - - Ok( - // SAFETY: we created the physical from the enum categories - unsafe { - CategoricalChunked::from_cats_and_rev_map_unchecked( - new_phys, - Arc::new(RevMapping::Local(categories.clone(), hash)), - true, - self.get_ordering(), - ) - }, - ) + // SAFETY: we created the physical from the enum categories + unsafe { + CategoricalChunked::from_cats_and_rev_map_unchecked( + new_phys, + Arc::new(RevMapping::Local(categories.clone(), hash)), + true, + self.get_ordering(), + ) + } } pub(crate) fn get_flags(&self) -> Settings { @@ -274,7 +258,9 @@ impl CategoricalChunked { } } - pub(crate) fn can_fast_unique(&self) -> bool { + /// True if all categories are represented in this array. When this is the case, the unique + /// values of the array are the categories. + pub fn _can_fast_unique(&self) -> bool { self.bit_settings.contains(BitSettings::ORIGINAL) && self.physical.chunks.len() == 1 && self.null_count() == 0 @@ -371,7 +357,7 @@ impl LogicalType for CategoricalChunked { polars_bail!(ComputeError: "can not cast to enum with global mapping") }; Ok(self - .to_enum(categories, *hash)? + .to_enum(categories, *hash) .set_ordering(*ordering, true) .into_series() .with_name(self.name())) @@ -448,8 +434,6 @@ impl<'a> ExactSizeIterator for CatIter<'a> {} #[cfg(test)] mod test { - use std::convert::TryFrom; - use super::*; use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index c30a482fcd186..a5477fccb3764 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -1,10 +1,9 @@ use super::*; -use crate::frame::group_by::IntoGroupsProxy; impl CategoricalChunked { pub fn unique(&self) -> PolarsResult { let cat_map = self.get_rev_map(); - if self.can_fast_unique() { + if self._can_fast_unique() { let ca = match &**cat_map { RevMapping::Local(a, _) => { UInt32Chunked::from_iter_values(self.physical().name(), 0..(a.len() as u32)) @@ -41,7 +40,7 @@ impl CategoricalChunked { } pub fn n_unique(&self) -> PolarsResult { - if self.can_fast_unique() { + if self._can_fast_unique() { Ok(self.get_rev_map().len()) } else { self.physical().n_unique() @@ -66,7 +65,7 @@ impl CategoricalChunked { let mut counts = groups.group_count(); counts.rename("counts"); let cols = vec![values.into_series(), counts.into_series()]; - let df = DataFrame::new_no_checks(cols); + let df = unsafe { DataFrame::new_no_checks(cols) }; df.sort(["counts"], true, false) } } diff --git a/crates/polars-core/src/chunked_array/logical/mod.rs b/crates/polars-core/src/chunked_array/logical/mod.rs index 4e3cdcb3a602e..33191cfafd3c5 100644 --- a/crates/polars-core/src/chunked_array/logical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/mod.rs @@ -33,7 +33,7 @@ pub use time::*; use crate::prelude::*; -/// Maps a logical type to a a chunked array implementation of the physical type. +/// Maps a logical type to a chunked array implementation of the physical type. /// This saves a lot of compiler bloat and allows us to reuse functionality. pub struct Logical( pub ChunkedArray, @@ -44,7 +44,7 @@ pub struct Logical( impl Clone for Logical { fn clone(&self) -> Self { let mut new = Logical::::new_logical(self.0.clone()); - new.2 = self.2.clone(); + new.2.clone_from(&self.2); new } } diff --git a/crates/polars-core/src/chunked_array/logical/struct_/from.rs b/crates/polars-core/src/chunked_array/logical/struct_/from.rs index 23e23410b3f0c..4ec5707672828 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/from.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/from.rs @@ -4,11 +4,11 @@ impl From for DataFrame { fn from(ca: StructChunked) -> Self { #[cfg(feature = "object")] { - DataFrame::new_no_checks(ca.fields.clone()) + unsafe { DataFrame::new_no_checks(ca.fields.clone()) } } #[cfg(not(feature = "object"))] { - DataFrame::new_no_checks(ca.fields) + unsafe { DataFrame::new_no_checks(ca.fields) } } } } diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index b7bf7eceeb746..247c132f3a2e6 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -52,7 +52,7 @@ use arrow::legacy::prelude::*; use bitflags::bitflags; use crate::series::IsSorted; -use crate::utils::{first_non_null, last_non_null, CustomIterTools}; +use crate::utils::{first_non_null, last_non_null}; #[cfg(not(feature = "dtype-categorical"))] pub struct RevMapping {} @@ -189,6 +189,11 @@ impl ChunkedArray { self.bit_settings.contains(Settings::SORTED_DSC) } + /// Whether `self` is sorted in any direction. + pub(crate) fn is_sorted_any(&self) -> bool { + self.is_sorted_ascending_flag() || self.is_sorted_descending_flag() + } + pub fn unset_fast_explode_list(&mut self) { self.bit_settings.remove(Settings::FAST_EXPLODE_LIST) } @@ -220,8 +225,28 @@ impl ChunkedArray { /// Get the index of the first non null value in this [`ChunkedArray`]. pub fn first_non_null(&self) -> Option { - if self.is_empty() { + if self.null_count() == self.len() { None + } + // We now know there is at least 1 non-null item in the array, and self.len() > 0 + else if self.null_count() == 0 { + Some(0) + } else if self.is_sorted_any() { + let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } { + // nulls are all at the start + self.null_count() + } else { + // nulls are all at the end + 0 + }; + + debug_assert!( + // If we are lucky this catches something. + unsafe { self.get_unchecked(out) }.is_some(), + "incorrect sorted flag" + ); + + Some(out) } else { first_non_null(self.iter_validities()) } @@ -229,7 +254,31 @@ impl ChunkedArray { /// Get the index of the last non null value in this [`ChunkedArray`]. pub fn last_non_null(&self) -> Option { - last_non_null(self.iter_validities(), self.length as usize) + if self.null_count() == self.len() { + None + } + // We now know there is at least 1 non-null item in the array, and self.len() > 0 + else if self.null_count() == 0 { + Some(self.len() - 1) + } else if self.is_sorted_any() { + let out = if unsafe { self.downcast_get_unchecked(0).is_null_unchecked(0) } { + // nulls are all at the start + self.len() - 1 + } else { + // nulls are all at the end + self.len() - self.null_count() - 1 + }; + + debug_assert!( + // If we are lucky this catches something. + unsafe { self.get_unchecked(out) }.is_some(), + "incorrect sorted flag" + ); + + Some(out) + } else { + last_non_null(self.iter_validities(), self.len()) + } } /// Get the buffer of bits representing null values diff --git a/crates/polars-core/src/chunked_array/object/builder.rs b/crates/polars-core/src/chunked_array/object/builder.rs index cb5348bb5e08a..1bf6727342a9f 100644 --- a/crates/polars-core/src/chunked_array/object/builder.rs +++ b/crates/polars-core/src/chunked_array/object/builder.rs @@ -1,5 +1,4 @@ use std::marker::PhantomData; -use std::sync::Arc; use arrow::bitmap::MutableBitmap; diff --git a/crates/polars-core/src/chunked_array/object/extension/mod.rs b/crates/polars-core/src/chunked_array/object/extension/mod.rs index 18b4bb554a769..df5797769b915 100644 --- a/crates/polars-core/src/chunked_array/object/extension/mod.rs +++ b/crates/polars-core/src/chunked_array/object/extension/mod.rs @@ -4,7 +4,7 @@ pub(crate) mod polars_extension; use std::mem; -use arrow::array::{Array, FixedSizeBinaryArray}; +use arrow::array::FixedSizeBinaryArray; use arrow::bitmap::MutableBitmap; use arrow::buffer::Buffer; use polars_extension::PolarsExtension; @@ -133,7 +133,9 @@ pub(crate) fn create_extension> + TrustedLen, T: Si #[cfg(test)] mod test { use std::fmt::{Display, Formatter}; + use std::hash::{Hash, Hasher}; + use polars_utils::total_ord::TotalHash; use polars_utils::unitvec; use super::*; @@ -151,6 +153,15 @@ mod test { } } + impl TotalHash for Foo { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } + } + impl Display for Foo { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}", self) diff --git a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs index 4cd2de7abad9f..6030f668dfe1a 100644 --- a/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs +++ b/crates/polars-core/src/chunked_array/object/extension/polars_extension.rs @@ -1,7 +1,5 @@ use std::mem::ManuallyDrop; -use arrow::array::FixedSizeBinaryArray; - use super::*; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/object/mod.rs b/crates/polars-core/src/chunked_array/object/mod.rs index 65fb98b4d96e1..9f17a1d1b4347 100644 --- a/crates/polars-core/src/chunked_array/object/mod.rs +++ b/crates/polars-core/src/chunked_array/object/mod.rs @@ -4,6 +4,7 @@ use std::hash::Hash; use arrow::bitmap::utils::{BitmapIter, ZipValidity}; use arrow::bitmap::Bitmap; +use polars_utils::total_ord::TotalHash; use crate::prelude::*; @@ -36,7 +37,7 @@ pub trait PolarsObjectSafe: Any + Debug + Send + Sync + Display { /// Values need to implement this so that they can be stored into a Series and DataFrame pub trait PolarsObject: - Any + Debug + Clone + Send + Sync + Default + Display + Hash + PartialEq + Eq + TotalEq + Any + Debug + Clone + Send + Sync + Default + Display + Hash + TotalHash + PartialEq + Eq + TotalEq { /// This should be used as type information. Consider this a part of the type system. fn type_name() -> &'static str; diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs index a8db5671eba29..c88795582f17a 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/mod.rs @@ -13,6 +13,9 @@ use polars_utils::min_max::MinMax; pub use quantile::*; pub use var::*; +use super::float_sorted_arg_max::{ + float_arg_max_sorted_ascending, float_arg_max_sorted_descending, +}; use crate::chunked_array::ChunkedArray; use crate::datatypes::{BooleanChunked, PolarsNumericType}; use crate::prelude::*; @@ -93,21 +96,18 @@ where } fn min(&self) -> Option { - if self.is_empty() { + if self.null_count() == self.len() { return None; } + // There is at least one non-null value. match self.is_sorted_flag() { IsSorted::Ascending => { - self.first_non_null().and_then(|idx| { - // SAFETY: first_non_null returns in bound index. - unsafe { self.get_unchecked(idx) } - }) + let idx = self.first_non_null().unwrap(); + unsafe { self.get_unchecked(idx) } }, IsSorted::Descending => { - self.last_non_null().and_then(|idx| { - // SAFETY: last returns in bound index. - unsafe { self.get_unchecked(idx) } - }) + let idx = self.last_non_null().unwrap(); + unsafe { self.get_unchecked(idx) } }, IsSorted::Not => self .downcast_iter() @@ -117,23 +117,28 @@ where } fn max(&self) -> Option { - if self.is_empty() { + if self.null_count() == self.len() { return None; } + // There is at least one non-null value. match self.is_sorted_flag() { IsSorted::Ascending => { - self.last_non_null().and_then(|idx| { - // SAFETY: - // last_non_null returns in bound index - unsafe { self.get_unchecked(idx) } - }) + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_ascending(self) + } else { + self.last_non_null().unwrap() + }; + + unsafe { self.get_unchecked(idx) } }, IsSorted::Descending => { - self.first_non_null().and_then(|idx| { - // SAFETY: - // first_non_null returns in bound index - unsafe { self.get_unchecked(idx) } - }) + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_descending(self) + } else { + self.first_non_null().unwrap() + }; + + unsafe { self.get_unchecked(idx) } }, IsSorted::Not => self .downcast_iter() @@ -143,30 +148,36 @@ where } fn min_max(&self) -> Option<(T::Native, T::Native)> { - if self.is_empty() { + if self.null_count() == self.len() { return None; } + // There is at least one non-null value. match self.is_sorted_flag() { IsSorted::Ascending => { - let min = self.first_non_null().and_then(|idx| { - // SAFETY: first_non_null returns in bound index. - unsafe { self.get_unchecked(idx) } - }); - let max = self.last_non_null().and_then(|idx| { - // SAFETY: last_non_null returns in bound index. + let min = unsafe { self.get_unchecked(self.first_non_null().unwrap()) }; + let max = { + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_ascending(self) + } else { + self.last_non_null().unwrap() + }; + unsafe { self.get_unchecked(idx) } - }); + }; min.zip(max) }, IsSorted::Descending => { - let max = self.first_non_null().and_then(|idx| { - // SAFETY: first_non_null returns in bound index. - unsafe { self.get_unchecked(idx) } - }); - let min = self.last_non_null().and_then(|idx| { - // SAFETY: last_non_null returns in bound index. + let min = unsafe { self.get_unchecked(self.last_non_null().unwrap()) }; + let max = { + let idx = if T::get_dtype().is_float() { + float_arg_max_sorted_descending(self) + } else { + self.first_non_null().unwrap() + }; + unsafe { self.get_unchecked(idx) } - }); + }; + min.zip(max) }, IsSorted::Not => self @@ -182,10 +193,10 @@ where } fn mean(&self) -> Option { - if self.is_empty() || self.null_count() == self.len() { + if self.null_count() == self.len() { return None; } - match self.dtype() { + match T::get_dtype() { DataType::Float64 => { let len = (self.len() - self.null_count()) as f64; self.sum().map(|v| v.to_f64().unwrap() / len) @@ -522,7 +533,7 @@ impl CategoricalChunked { } if self.uses_lexical_ordering() { // Fast path where all categories are used - if self.can_fast_unique() { + if self._can_fast_unique() { self.get_rev_map().get_categories().min_ignore_nan_kernel() } else { let rev_map = self.get_rev_map(); @@ -550,7 +561,7 @@ impl CategoricalChunked { } if self.uses_lexical_ordering() { // Fast path where all categories are used - if self.can_fast_unique() { + if self._can_fast_unique() { self.get_rev_map().get_categories().max_ignore_nan_kernel() } else { let rev_map = self.get_rev_map(); @@ -650,8 +661,6 @@ impl ChunkAggSeries for ObjectChunked {} #[cfg(test)] mod test { - use arrow::legacy::prelude::QuantileInterpolOptions; - use crate::prelude::*; #[test] diff --git a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs index 691f66e568c4d..ce528337f0c12 100644 --- a/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs +++ b/crates/polars-core/src/chunked_array/ops/aggregate/quantile.rs @@ -1,5 +1,3 @@ -use arrow::legacy::prelude::QuantileInterpolOptions; - use super::*; pub trait QuantileAggSeries { diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 027ccb09d1682..3325519eddb67 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -4,7 +4,7 @@ use crate::series::IsSorted; pub(crate) fn new_chunks(chunks: &mut Vec, other: &[ArrayRef], len: usize) { // Replace an empty array. if chunks.len() == 1 && len == 0 { - *chunks = other.to_owned(); + other.clone_into(chunks); } else { for chunk in other { if chunk.len() > 0 { @@ -19,51 +19,114 @@ where T: PolarsDataType, for<'a> T::Physical<'a>: TotalOrd, { - // TODO: attempt to maintain sortedness better in case of nulls. + // Note: Do not call (first|last)_non_null on an array here before checking + // it is sorted, otherwise it will lead to quadratic behavior. + let sorted_flag = match ( + ca.null_count() != ca.len(), + other.null_count() != other.len(), + ) { + (false, false) => IsSorted::Ascending, + (false, true) => { + if + // lhs is empty, just take sorted flag from rhs + ca.is_empty() + || ( + // lhs is non-empty and all-null, so rhs must have nulls ordered first + other.is_sorted_any() && 1 + other.last_non_null().unwrap() == other.len() + ) + { + other.is_sorted_flag() + } else { + IsSorted::Not + } + }, + (true, false) => { + if + // rhs is empty, just take sorted flag from lhs + other.is_empty() + || ( + // rhs is non-empty and all-null, so lhs must have nulls ordered last + ca.is_sorted_any() && ca.first_non_null().unwrap() == 0 + ) + { + ca.is_sorted_flag() + } else { + IsSorted::Not + } + }, + (true, true) => { + // both arrays have non-null values. + // for arrays of unit length we can ignore the sorted flag, as it is + // not necessarily set. + if !(ca.is_sorted_any() || ca.len() == 1) + || !(other.is_sorted_any() || other.len() == 1) + || !( + // We will coerce for single values + ca.len() - ca.null_count() == 1 + || other.len() - other.null_count() == 1 + || ca.is_sorted_flag() == other.is_sorted_flag() + ) + { + IsSorted::Not + } else { + let l_idx = ca.last_non_null().unwrap(); + let r_idx = other.first_non_null().unwrap(); - // If either is empty, copy the sorted flag from the other. - if ca.is_empty() { - ca.set_sorted_flag(other.is_sorted_flag()); - return; - } - if other.is_empty() { - return; - } + let l_val = unsafe { ca.value_unchecked(l_idx) }; + let r_val = unsafe { other.value_unchecked(r_idx) }; - // Both need to be sorted, in the same order, if the order is maintained. - // TODO: rework sorted flags, ascending and descending are not mutually - // exclusive for all-equal/all-null arrays. - let ls = ca.is_sorted_flag(); - let rs = other.is_sorted_flag(); - if ls != rs || ls == IsSorted::Not || rs == IsSorted::Not { - ca.set_sorted_flag(IsSorted::Not); - return; - } + let null_pos_check = + // check null positions + // lhs does not end in nulls + (1 + l_idx == ca.len()) + // rhs does not start with nulls + && (r_idx == 0) + // if there are nulls, they are all on one end + && !(ca.first_non_null().unwrap() != 0 && 1 + other.last_non_null().unwrap() != other.len()); - // Check the order is maintained. - let still_sorted = { - // To prevent potential quadratic append behavior we do not find - // the last non-null element in ca. - if let Some(left) = ca.last() { - if let Some(right_idx) = other.first_non_null() { - let right = other.get(right_idx).unwrap(); - if ca.is_sorted_ascending_flag() { - left.tot_le(&right) + if !null_pos_check { + IsSorted::Not } else { - left.tot_ge(&right) + #[allow(unused_assignments)] + let mut out = IsSorted::Not; + + #[allow(clippy::never_loop)] + loop { + match ( + ca.len() - ca.null_count() == 1, + other.len() - other.null_count() == 1, + ) { + (true, true) => { + out = [IsSorted::Descending, IsSorted::Ascending] + [l_val.tot_le(&r_val) as usize]; + break; + }, + (true, false) => out = other.is_sorted_flag(), + _ => out = ca.is_sorted_flag(), + } + + debug_assert!(!matches!(out, IsSorted::Not)); + + let check = if matches!(out, IsSorted::Ascending) { + l_val.tot_le(&r_val) + } else { + l_val.tot_ge(&r_val) + }; + + if !check { + out = IsSorted::Not + } + + break; + } + + out } - } else { - // Right is only nulls, trivially sorted. - true } - } else { - // Last element in left is null, pessimistically assume not sorted. - false - } + }, }; - if !still_sorted { - ca.set_sorted_flag(IsSorted::Not); - } + + ca.set_sorted_flag(sorted_flag); } impl ChunkedArray diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index 9f8acfeedb24b..c4d77df7e09e4 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -1,14 +1,11 @@ //! Implementations of the ChunkApply Trait. use std::borrow::Cow; -use std::convert::TryFrom; -use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::utils::{get_bit_unchecked, set_bit_unchecked}; use arrow::legacy::bitmap::unary_mut; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::CustomIterTools; impl ChunkedArray where diff --git a/crates/polars-core/src/chunked_array/ops/chunkops.rs b/crates/polars-core/src/chunked_array/ops/chunkops.rs index f2cdc41d51c2e..6a061c280eaf0 100644 --- a/crates/polars-core/src/chunked_array/ops/chunkops.rs +++ b/crates/polars-core/src/chunked_array/ops/chunkops.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "object")] -use arrow::array::Array; use arrow::legacy::kernels::concatenate::concatenate_owned_unchecked; use polars_error::constants::LENGTH_LIMIT_MSG; @@ -92,7 +90,10 @@ impl ChunkedArray { _ => chunks.iter().fold(0, |acc, arr| acc + arr.len()), } } - self.length = IdxSize::try_from(inner(&self.chunks)).expect(LENGTH_LIMIT_MSG); + let len = inner(&self.chunks); + // Length limit is `IdxSize::MAX - 1`. We use `IdxSize::MAX` to indicate `NULL` in indexing. + assert!(len < IdxSize::MAX as usize, "{}", LENGTH_LIMIT_MSG); + self.length = len as IdxSize; self.null_count = self .chunks .iter() diff --git a/crates/polars-core/src/chunked_array/ops/explode.rs b/crates/polars-core/src/chunked_array/ops/explode.rs index 573b4379158b4..6107b10ab6969 100644 --- a/crates/polars-core/src/chunked_array/ops/explode.rs +++ b/crates/polars-core/src/chunked_array/ops/explode.rs @@ -1,9 +1,6 @@ -use std::convert::TryFrom; - use arrow::array::*; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::legacy::array::list::AnonymousBuilder; -use arrow::legacy::array::PolarsArray; use arrow::legacy::bit_util::unset_bit_raw; #[cfg(feature = "dtype-array")] use arrow::legacy::is_valid::IsValid; diff --git a/crates/polars-core/src/chunked_array/ops/fill_null.rs b/crates/polars-core/src/chunked_array/ops/fill_null.rs index 9458021cf92dc..e52cef3c296ec 100644 --- a/crates/polars-core/src/chunked_array/ops/fill_null.rs +++ b/crates/polars-core/src/chunked_array/ops/fill_null.rs @@ -1,6 +1,6 @@ use arrow::legacy::kernels::set::set_at_nulls; use arrow::legacy::trusted_len::FromIteratorReversed; -use arrow::legacy::utils::{CustomIterTools, FromTrustedLenIterator}; +use arrow::legacy::utils::FromTrustedLenIterator; use num_traits::{Bounded, NumCast, One, Zero}; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs index b07b9703b3887..1ab76a60077f1 100644 --- a/crates/polars-core/src/chunked_array/ops/filter.rs +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "object")] -use arrow::array::Array; use polars_compute::filter::filter as filter_fn; #[cfg(feature = "object")] diff --git a/crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs b/crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs new file mode 100644 index 0000000000000..0ec721f1e40a4 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/float_sorted_arg_max.rs @@ -0,0 +1,87 @@ +//! Implementations of the ChunkAgg trait. +use num_traits::Float; + +use self::search_sorted::{ + binary_search_array, slice_sorted_non_null_and_offset, SearchSortedSide, +}; +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsFloatType, + T::Native: Float, +{ + fn float_arg_max_sorted_ascending(&self) -> usize { + let ca = self; + debug_assert!(ca.is_sorted_ascending_flag()); + let is_descending = false; + let side = SearchSortedSide::Left; + + let maybe_max_idx = ca.last_non_null().unwrap(); + + let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) }; + if !maybe_max.is_nan() { + return maybe_max_idx; + } + + let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) }; + let arr = unsafe { ca.downcast_get_unchecked(0) }; + let search_val = T::Native::nan(); + let idx = binary_search_array(side, arr, search_val, is_descending) as usize; + + let idx = idx.saturating_sub(1); + + offset + idx + } + + fn float_arg_max_sorted_descending(&self) -> usize { + let ca = self; + debug_assert!(ca.is_sorted_descending_flag()); + let is_descending = true; + let side = SearchSortedSide::Right; + + let maybe_max_idx = ca.first_non_null().unwrap(); + + let maybe_max = unsafe { ca.value_unchecked(maybe_max_idx) }; + if !maybe_max.is_nan() { + return maybe_max_idx; + } + + let (offset, ca) = unsafe { slice_sorted_non_null_and_offset(ca) }; + let arr = unsafe { ca.downcast_get_unchecked(0) }; + let search_val = T::Native::nan(); + let idx = binary_search_array(side, arr, search_val, is_descending) as usize; + + let idx = if idx == arr.len() { idx - 1 } else { idx }; + + offset + idx + } +} + +/// # Safety +/// `ca` has a float dtype, has at least 1 non-null value and is sorted ascending +pub fn float_arg_max_sorted_ascending(ca: &ChunkedArray) -> usize +where + T: PolarsNumericType, +{ + with_match_physical_float_polars_type!(ca.dtype(), |$T| { + let ca: &ChunkedArray<$T> = unsafe { + &*(ca as *const ChunkedArray as *const ChunkedArray<$T>) + }; + ca.float_arg_max_sorted_ascending() + }) +} + +/// # Safety +/// `ca` has a float dtype, has at least 1 non-null value and is sorted descending +pub fn float_arg_max_sorted_descending(ca: &ChunkedArray) -> usize +where + T: PolarsNumericType, +{ + with_match_physical_float_polars_type!(ca.dtype(), |$T| { + let ca: &ChunkedArray<$T> = unsafe { + &*(ca as *const ChunkedArray as *const ChunkedArray<$T>) + }; + ca.float_arg_max_sorted_descending() + }) +} diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index 21616823fe79c..16ba9d5f0ba8d 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -1,5 +1,4 @@ use arrow::bitmap::MutableBitmap; -use arrow::legacy::array::default_arrays::FromData; use crate::chunked_array::builder::get_list_builder; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs index 5db4c28ece6c8..fbc2d827db7fd 100644 --- a/crates/polars-core/src/chunked_array/ops/gather.rs +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -1,14 +1,12 @@ -use arrow::array::Array; use arrow::bitmap::bitmask::BitMask; +use arrow::bitmap::Bitmap; use arrow::compute::take::take_unchecked; -use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_error::polars_ensure; use polars_utils::index::check_bounds; use crate::chunked_array::collect::prepare_collect_dtype; -use crate::chunked_array::ops::{ChunkTake, ChunkTakeUnchecked}; -use crate::chunked_array::ChunkedArray; -use crate::datatypes::{IdxCa, PolarsDataType, StaticArray}; use crate::prelude::*; +use crate::series::IsSorted; const BINARY_SEARCH_LIMIT: usize = 8; @@ -187,6 +185,18 @@ impl NotSpecialized for DecimalType {} #[cfg(feature = "object")] impl NotSpecialized for ObjectType {} +pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted { + use crate::series::IsSorted::*; + match (sorted_arr, sorted_idx) { + (_, Not) => Not, + (Not, _) => Not, + (Ascending, Ascending) => Ascending, + (Ascending, Descending) => Descending, + (Descending, Ascending) => Descending, + (Descending, Descending) => Ascending, + } +} + impl ChunkTakeUnchecked for ChunkedArray { /// Gather values from ChunkedArray by index. unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { @@ -233,16 +243,8 @@ impl ChunkTakeUnchecked for ChunkedAr }); let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag()); - use crate::series::IsSorted::*; - let sorted_flag = match (ca.is_sorted_flag(), indices.is_sorted_flag()) { - (_, Not) => Not, - (Not, _) => Not, - (Ascending, Ascending) => Ascending, - (Ascending, Descending) => Descending, - (Descending, Ascending) => Descending, - (Descending, Descending) => Ascending, - }; out.set_sorted_flag(sorted_flag); out } @@ -262,15 +264,8 @@ impl ChunkTakeUnchecked for BinaryChunked { let mut out = ChunkedArray::from_chunks(self.name(), chunks); - use crate::series::IsSorted::*; - let sorted_flag = match (self.is_sorted_flag(), indices.is_sorted_flag()) { - (_, Not) => Not, - (Not, _) => Not, - (Ascending, Ascending) => Ascending, - (Ascending, Descending) => Descending, - (Descending, Ascending) => Descending, - (Descending, Descending) => Ascending, - }; + let sorted_flag = + _update_gather_sorted_flag(self.is_sorted_flag(), indices.is_sorted_flag()); out.set_sorted_flag(sorted_flag); out } @@ -281,3 +276,15 @@ impl ChunkTakeUnchecked for StringChunked { self.as_binary().take_unchecked(indices).to_string() } } + +impl IdxCa { + pub fn with_nullable_idx T>(idx: &[NullableIdxSize], f: F) -> T { + let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted(); + let idx = bytemuck::cast_slice::<_, IdxSize>(idx); + let arr = unsafe { arrow::ffi::mmap::slice(idx) }; + let arr = arr.with_validity_typed(Some(validity)); + let ca = IdxCa::with_chunk("", arr); + + f(&ca) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs index 279a4ae0719f3..bc33f088b1f9d 100644 --- a/crates/polars-core/src/chunked_array/ops/min_max_binary.rs +++ b/crates/polars-core/src/chunked_array/ops/min_max_binary.rs @@ -1,4 +1,3 @@ -use crate::datatypes::PolarsNumericType; use crate::prelude::*; use crate::series::arithmetic::coerce_lhs_rhs; diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 9fe082c3a9da8..acd5df24bfb86 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -1,9 +1,6 @@ //! Traits for miscellaneous operations on ChunkedArray -use arrow::legacy::prelude::QuantileInterpolOptions; use arrow::offset::OffsetsBuffer; -#[cfg(feature = "object")] -use crate::datatypes::ObjectType; use crate::prelude::*; pub(crate) mod aggregate; @@ -22,6 +19,7 @@ mod explode_and_offsets; mod extend; pub mod fill_null; mod filter; +pub mod float_sorted_arg_max; mod for_each; pub mod full; pub mod gather; @@ -32,6 +30,7 @@ pub(crate) mod min_max_binary; pub(crate) mod nulls; mod reverse; pub(crate) mod rolling_window; +pub mod search_sorted; mod set; mod shift; pub mod sort; diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs index c5910dcc88634..62d742fe62843 100644 --- a/crates/polars-core/src/chunked_array/ops/reverse.rs +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -2,7 +2,7 @@ use crate::chunked_array::builder::get_fixed_size_list_builder; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::{CustomIterTools, NoNull}; +use crate::utils::NoNull; impl ChunkReverse for ChunkedArray where diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index 3d1d7270c3887..2345f83df00d4 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -30,7 +30,6 @@ impl Default for RollingOptionsFixedWindow { mod inner_mod { use std::ops::SubAssign; - use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::MutableBitmap; use arrow::legacy::bit_util::unset_bit_raw; use arrow::legacy::trusted_len::TrustedLenPush; diff --git a/crates/polars-core/src/chunked_array/ops/search_sorted.rs b/crates/polars-core/src/chunked_array/ops/search_sorted.rs new file mode 100644 index 0000000000000..823b8f9a641eb --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/search_sorted.rs @@ -0,0 +1,128 @@ +use std::cmp::Ordering; +use std::fmt::Debug; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::prelude::*; + +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum SearchSortedSide { + #[default] + Any, + Left, + Right, +} + +/// Search the left or right index that still fulfills the requirements. +fn get_side_idx<'a, A>(side: SearchSortedSide, mid: IdxSize, arr: &'a A, len: usize) -> IdxSize +where + A: StaticArray, + A::ValueT<'a>: TotalOrd + Debug + Copy, +{ + let mut mid = mid; + + // approach the boundary from any side + // this is O(n) we could make this binary search later + match side { + SearchSortedSide::Any => mid, + SearchSortedSide::Left => { + if mid as usize == len { + mid -= 1; + } + + let current = unsafe { arr.get_unchecked(mid as usize) }; + loop { + if mid == 0 { + return mid; + } + mid -= 1; + if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) { + return mid + 1; + } + } + }, + SearchSortedSide::Right => { + if mid as usize == len { + return mid; + } + let current = unsafe { arr.get_unchecked(mid as usize) }; + let bound = (len - 1) as IdxSize; + loop { + if mid >= bound { + return mid + 1; + } + mid += 1; + if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) { + return mid; + } + } + }, + } +} + +pub fn binary_search_array<'a, A>( + side: SearchSortedSide, + arr: &'a A, + search_value: A::ValueT<'a>, + descending: bool, +) -> IdxSize +where + A: StaticArray, + A::ValueT<'a>: TotalOrd + Debug + Copy, +{ + let mut size = arr.len() as IdxSize; + let mut left = 0 as IdxSize; + let mut right = size; + while left < right { + let mid = left + size / 2; + + // SAFETY: the call is made safe by the following invariants: + // - `mid >= 0` + // - `mid < size`: `mid` is limited by `[left; right)` bound. + let cmp = match unsafe { arr.get_unchecked(mid as usize) } { + None => Ordering::Less, + Some(value) => { + if descending { + search_value.tot_cmp(&value) + } else { + value.tot_cmp(&search_value) + } + }, + }; + + // The reason why we use if/else control flow rather than match + // is because match reorders comparison operations, which is perf sensitive. + // This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra. + if cmp == Ordering::Less { + left = mid + 1; + } else if cmp == Ordering::Greater { + right = mid; + } else { + return get_side_idx(side, mid, arr, arr.len()); + } + + size = right - left; + } + + left +} + +/// Get a slice of the non-null values of a sorted array. The returned array +/// will have a single chunk. +/// # Safety +/// The array is sorted and has at least one non-null value. +pub unsafe fn slice_sorted_non_null_and_offset(ca: &ChunkedArray) -> (usize, ChunkedArray) +where + T: PolarsDataType, +{ + let offset = ca.first_non_null().unwrap(); + let length = 1 + ca.last_non_null().unwrap() - offset; + let out = ca.slice(offset as i64, length); + + debug_assert!(out.null_count() != out.len()); + debug_assert!(out.null_count() == 0); + + (offset, out.rechunk()) +} diff --git a/crates/polars-core/src/chunked_array/ops/set.rs b/crates/polars-core/src/chunked_array/ops/set.rs index 0c9cdbd0f4aa6..52646925a05cf 100644 --- a/crates/polars-core/src/chunked_array/ops/set.rs +++ b/crates/polars-core/src/chunked_array/ops/set.rs @@ -1,9 +1,8 @@ use arrow::bitmap::MutableBitmap; use arrow::legacy::kernels::set::{scatter_single_non_null, set_with_mask}; -use arrow::legacy::prelude::FromData; use crate::prelude::*; -use crate::utils::{align_chunks_binary, CustomIterTools}; +use crate::utils::align_chunks_binary; macro_rules! impl_scatter_with { ($self:ident, $builder:ident, $idx:ident, $f:ident) => {{ diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 20f2851693439..c1e2fe3791554 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -4,7 +4,6 @@ use polars_utils::iter::EnumerateIdxTrait; use super::*; #[cfg(feature = "dtype-struct")] use crate::utils::_split_offsets; -use crate::POOL; pub(crate) fn args_validate( ca: &ChunkedArray, @@ -89,17 +88,15 @@ pub(crate) fn encode_rows_vertical(by: &[Series]) -> PolarsResult> = splits - .into_par_iter() - .map(|(offset, len)| { - let sliced = by - .iter() - .map(|s| s.slice(offset as i64, len)) - .collect::>(); - let rows = _get_rows_encoded(&sliced, &descending, false)?; - Ok(rows.into_array()) - }) - .collect(); + let chunks = splits.into_par_iter().map(|(offset, len)| { + let sliced = by + .iter() + .map(|s| s.slice(offset as i64, len)) + .collect::>(); + let rows = _get_rows_encoded(&sliced, &descending, false)?; + Ok(rows.into_array()) + }); + let chunks = POOL.install(|| chunks.collect::>>()); Ok(BinaryOffsetChunked::from_chunk_iter("", chunks?)) } diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 2b07c27ba83fa..9f7de82edc87b 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -7,10 +7,8 @@ mod categorical; use std::cmp::Ordering; pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt; -use arrow::array::ValueSize; use arrow::bitmap::MutableBitmap; use arrow::buffer::Buffer; -use arrow::legacy::prelude::FromData; use arrow::legacy::trusted_len::TrustedLenPush; use rayon::prelude::*; pub use slice::*; @@ -21,7 +19,7 @@ use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; use crate::prelude::sort::arg_sort_multiple::{arg_sort_multiple_impl, args_validate}; use crate::prelude::*; use crate::series::IsSorted; -use crate::utils::{CustomIterTools, NoNull}; +use crate::utils::NoNull; use crate::POOL; pub(crate) fn sort_by_branch(slice: &mut [T], descending: bool, cmp: C, parallel: bool) @@ -631,6 +629,10 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult .collect::>>()?; return StructChunked::new(ca.name(), &new_fields).map(|ca| ca.into_series()); }, + // we could fallback to default branch, but decimal is not numeric dtype for now, so explicit here + #[cfg(feature = "dtype-decimal")] + Decimal(_, _) => s.clone(), + List(inner) if !inner.is_nested() => s.clone(), _ => { let phys = s.to_physical_repr().into_owned(); polars_ensure!( diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index 34e6946f7e7f9..3820475786826 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -1,11 +1,8 @@ use std::hash::Hash; use arrow::bitmap::MutableBitmap; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; -#[cfg(feature = "object")] -use crate::datatypes::ObjectType; -use crate::datatypes::PlHashSet; -use crate::frame::group_by::GroupsProxy; use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::series::IsSorted; @@ -60,12 +57,13 @@ impl ChunkUnique> for ObjectChunked { fn arg_unique(a: impl Iterator, capacity: usize) -> Vec where - T: Hash + Eq, + T: ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut set = PlHashSet::new(); let mut unique = Vec::with_capacity(capacity); a.enumerate().for_each(|(idx, val)| { - if set.insert(val) { + if set.insert(val.to_total_ord()) { unique.push(idx as IdxSize) } }); @@ -83,8 +81,9 @@ macro_rules! arg_unique_ca { impl ChunkUnique for ChunkedArray where - T: PolarsIntegerType, - T::Native: Hash + Eq + Ord, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + Ord, ChunkedArray: IntoSeries + for<'a> ChunkCompare<&'a ChunkedArray, Item = BooleanChunked>, { fn unique(&self) -> PolarsResult { @@ -96,25 +95,23 @@ where IsSorted::Ascending | IsSorted::Descending => { if self.null_count() > 0 { let mut arr = MutablePrimitiveArray::with_capacity(self.len()); - let mut iter = self.into_iter(); - let mut last = None; - if let Some(val) = iter.next() { - last = val; - arr.push(val) - }; + if !self.is_empty() { + let mut iter = self.iter(); + let last = iter.next().unwrap(); + arr.push(last); + let mut last = last.to_total_ord(); - #[allow(clippy::unnecessary_filter_map)] - let to_extend = iter.filter_map(|opt_val| { - if opt_val != last { - last = opt_val; - Some(opt_val) - } else { - None - } - }); + let to_extend = iter.filter(|opt_val| { + let opt_val_tot_ord = opt_val.to_total_ord(); + let out = opt_val_tot_ord != last; + last = opt_val_tot_ord; + out + }); + + arr.extend(to_extend); + } - arr.extend(to_extend); let arr: PrimitiveArray = arr.into(); Ok(ChunkedArray::with_chunk(self.name(), arr)) } else { @@ -142,15 +139,18 @@ where IsSorted::Ascending | IsSorted::Descending => { if self.null_count() > 0 { let mut count = 0; - let mut iter = self.into_iter(); - let mut last = None; - if let Some(val) = iter.next() { - last = val; - count += 1; - }; + if self.is_empty() { + return Ok(count); + } + + let mut iter = self.iter(); + let mut last = iter.next().unwrap().to_total_ord(); + + count += 1; iter.for_each(|opt_val| { + let opt_val = opt_val.to_total_ord(); if opt_val != last { last = opt_val; count += 1; @@ -254,30 +254,6 @@ impl ChunkUnique for BooleanChunked { } } -impl ChunkUnique for Float32Chunked { - fn unique(&self) -> PolarsResult> { - let ca = self.bit_repr_small(); - let ca = ca.unique()?; - Ok(ca._reinterpret_float()) - } - - fn arg_unique(&self) -> PolarsResult { - self.bit_repr_small().arg_unique() - } -} - -impl ChunkUnique for Float64Chunked { - fn unique(&self) -> PolarsResult> { - let ca = self.bit_repr_large(); - let ca = ca.unique()?; - Ok(ca._reinterpret_float()) - } - - fn arg_unique(&self) -> PolarsResult { - self.bit_repr_large().arg_unique() - } -} - #[cfg(test)] mod test { use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 8033b4d80f2b1..80b3bcdfd8157 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -1,8 +1,7 @@ use arrow::compute::if_then_else::if_then_else; -use arrow::legacy::array::default_arrays::FromData; use crate::prelude::*; -use crate::utils::{align_chunks_ternary, CustomIterTools}; +use crate::utils::align_chunks_ternary; fn ternary_apply(predicate: bool, truthy: T, falsy: T) -> T { if predicate { diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 7476183eab09f..18b1117669fca 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -3,12 +3,12 @@ use polars_error::to_compute_err; use rand::distributions::Bernoulli; use rand::prelude::*; use rand::seq::index::IndexVec; -use rand_distr::{Distribution, Normal, Standard, StandardNormal, Uniform}; +use rand_distr::{Normal, Standard, StandardNormal, Uniform}; use crate::prelude::DataType::Float64; use crate::prelude::*; use crate::random::get_global_random_u64; -use crate::utils::{CustomIterTools, NoNull}; +use crate::utils::NoNull; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option) -> IdxCa { if len == 0 { @@ -194,7 +194,7 @@ impl DataFrame { Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), None => { let new_cols = self.columns.iter().map(Series::clear).collect_trusted(); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) }, } } @@ -239,7 +239,7 @@ impl DataFrame { }, None => { let new_cols = self.columns.iter().map(Series::clear).collect_trusted(); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) }, } } diff --git a/crates/polars-core/src/chunked_array/temporal/conversion.rs b/crates/polars-core/src/chunked_array/temporal/conversion.rs index 34baa7c7533e9..f54c17d4081e9 100644 --- a/crates/polars-core/src/chunked_array/temporal/conversion.rs +++ b/crates/polars-core/src/chunked_array/temporal/conversion.rs @@ -10,9 +10,7 @@ impl From<&AnyValue<'_>> for NaiveDateTime { fn from(v: &AnyValue) -> Self { match v { #[cfg(feature = "dtype-date")] - AnyValue::Date(v) => { - NaiveDateTime::from_timestamp_opt(*v as i64 * SECONDS_IN_DAY, 0).unwrap() - }, + AnyValue::Date(v) => date32_to_datetime(*v), #[cfg(feature = "dtype-datetime")] AnyValue::Datetime(v, tu, _) => match tu { TimeUnit::Nanoseconds => timestamp_ns_to_datetime(*v), @@ -36,18 +34,18 @@ impl From<&AnyValue<'_>> for NaiveTime { // Used by lazy for literal conversion pub fn datetime_to_timestamp_ns(v: NaiveDateTime) -> i64 { - v.timestamp_nanos_opt().unwrap() + v.and_utc().timestamp_nanos_opt().unwrap() } // Used by lazy for literal conversion pub fn datetime_to_timestamp_ms(v: NaiveDateTime) -> i64 { - v.timestamp_millis() + v.and_utc().timestamp_millis() } // Used by lazy for literal conversion pub fn datetime_to_timestamp_us(v: NaiveDateTime) -> i64 { - let us = v.timestamp() * 1_000_000; - us + v.timestamp_subsec_micros() as i64 + let us = v.and_utc().timestamp() * 1_000_000; + us + v.and_utc().timestamp_subsec_micros() as i64 } pub(crate) fn naive_datetime_to_date(v: NaiveDateTime) -> i32 { diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs index b94c151181a67..bd3e6fae1c479 100644 --- a/crates/polars-core/src/chunked_array/temporal/datetime.rs +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -4,16 +4,10 @@ use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, }; use chrono::format::{DelayedFormat, StrftimeItems}; -use chrono::NaiveDate; #[cfg(feature = "timezones")] use chrono::TimeZone as TimeZoneTrait; -#[cfg(feature = "timezones")] -use chrono_tz::Tz; -use super::conversion::{datetime_to_timestamp_ms, datetime_to_timestamp_ns}; use super::*; -#[cfg(feature = "timezones")] -use crate::chunked_array::temporal::validate_time_zone; use crate::prelude::DataType::Datetime; use crate::prelude::*; diff --git a/crates/polars-core/src/chunked_array/temporal/mod.rs b/crates/polars-core/src/chunked_array/temporal/mod.rs index f761214f85a61..de58e421dc0fa 100644 --- a/crates/polars-core/src/chunked_array/temporal/mod.rs +++ b/crates/polars-core/src/chunked_array/temporal/mod.rs @@ -26,10 +26,6 @@ pub use self::conversion::*; #[cfg(feature = "timezones")] use crate::prelude::{polars_bail, PolarsResult}; -pub fn unix_time() -> NaiveDateTime { - NaiveDateTime::from_timestamp_opt(0, 0).unwrap() -} - #[cfg(feature = "timezones")] static FIXED_OFFSET_PATTERN: &str = r#"(?x) ^ diff --git a/crates/polars-core/src/chunked_array/trusted_len.rs b/crates/polars-core/src/chunked_array/trusted_len.rs index a241e0432569e..baa473cc07e12 100644 --- a/crates/polars-core/src/chunked_array/trusted_len.rs +++ b/crates/polars-core/src/chunked_array/trusted_len.rs @@ -4,7 +4,7 @@ use arrow::legacy::trusted_len::{FromIteratorReversed, TrustedLenPush}; use crate::chunked_array::upstream_traits::PolarsAsRef; use crate::prelude::*; -use crate::utils::{CustomIterTools, FromTrustedLenIterator, NoNull}; +use crate::utils::{FromTrustedLenIterator, NoNull}; impl FromTrustedLenIterator> for ChunkedArray where @@ -193,7 +193,6 @@ impl FromTrustedLenIterator> for ObjectChunked { #[cfg(test)] mod test { use super::*; - use crate::utils::CustomIterTools; #[test] fn test_reverse_collect() { diff --git a/crates/polars-core/src/chunked_array/upstream_traits.rs b/crates/polars-core/src/chunked_array/upstream_traits.rs index 3975e95414460..ce0fbcf4ad7b2 100644 --- a/crates/polars-core/src/chunked_array/upstream_traits.rs +++ b/crates/polars-core/src/chunked_array/upstream_traits.rs @@ -1,14 +1,10 @@ //! Implementations of upstream traits for [`ChunkedArray`] use std::borrow::{Borrow, Cow}; use std::collections::LinkedList; -use std::iter::FromIterator; use std::marker::PhantomData; -use std::sync::Arc; -use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::{Bitmap, MutableBitmap}; use polars_utils::sync::SyncPtr; -use rayon::iter::{FromParallelIterator, IntoParallelIterator}; use rayon::prelude::*; use crate::chunked_array::builder::{ @@ -22,7 +18,7 @@ use crate::chunked_array::object::builder::get_object_type; use crate::chunked_array::object::ObjectArray; use crate::prelude::*; use crate::utils::flatten::flatten_par; -use crate::utils::{get_iter_capacity, CustomIterTools, NoNull}; +use crate::utils::{get_iter_capacity, NoNull}; impl Default for ChunkedArray { fn default() -> Self { diff --git a/crates/polars-core/src/config.rs b/crates/polars-core/src/config.rs index f6cd7fff1b284..dee5e0103e54c 100644 --- a/crates/polars-core/src/config.rs +++ b/crates/polars-core/src/config.rs @@ -54,3 +54,9 @@ pub fn get_rg_prefetch_size() -> usize { // Set it to something big, but not unlimited. .unwrap_or_else(|_| std::cmp::max(get_file_prefetch_size(), 128)) } + +pub fn env_force_async() -> bool { + std::env::var("POLARS_FORCE_ASYNC") + .map(|value| value == "1") + .unwrap_or_default() +} diff --git a/crates/polars-core/src/datatypes/_serde.rs b/crates/polars-core/src/datatypes/_serde.rs index 69642e7497961..922ba5b95b3fa 100644 --- a/crates/polars-core/src/datatypes/_serde.rs +++ b/crates/polars-core/src/datatypes/_serde.rs @@ -4,8 +4,8 @@ //! We could use [serde_1712](https://github.com/serde-rs/serde/issues/1712), but that gave problems caused by //! [rust_96956](https://github.com/rust-lang/rust/issues/96956), so we make a dummy type without static -use serde::de::{SeqAccess, Visitor}; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::de::SeqAccess; +use serde::{Deserialize, Serialize}; use super::*; diff --git a/crates/polars-core/src/datatypes/aliases.rs b/crates/polars-core/src/datatypes/aliases.rs index 42ecbd018bdf0..b4ff2e9075c26 100644 --- a/crates/polars-core/src/datatypes/aliases.rs +++ b/crates/polars-core/src/datatypes/aliases.rs @@ -1,4 +1,4 @@ -pub use arrow::legacy::index::{IdxArr, IdxSize}; +pub use arrow::legacy::index::IdxArr; pub use polars_utils::aliases::{InitHashMaps, PlHashMap, PlHashSet, PlIndexMap, PlIndexSet}; use super::*; diff --git a/crates/polars-core/src/datatypes/any_value.rs b/crates/polars-core/src/datatypes/any_value.rs index 3f47896d5d871..523b7a9939d48 100644 --- a/crates/polars-core/src/datatypes/any_value.rs +++ b/crates/polars-core/src/datatypes/any_value.rs @@ -1,15 +1,12 @@ #[cfg(feature = "dtype-struct")] use arrow::legacy::trusted_len::TrustedLenPush; -#[cfg(feature = "dtype-date")] -use arrow::temporal_conversions::{ - timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, -}; use arrow::types::PrimitiveType; use polars_utils::format_smartstring; #[cfg(feature = "dtype-struct")] use polars_utils::slice::GetSaferUnchecked; #[cfg(feature = "dtype-categorical")] use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::ToTotalOrd; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; @@ -490,130 +487,126 @@ impl<'a> AnyValue<'a> { } } + /// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`, + /// if possible. + /// pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult> { - fn cast_to_numeric<'a>(av: &AnyValue, dtype: &'a DataType) -> PolarsResult> { - let out = match dtype { - DataType::UInt8 => AnyValue::UInt8(av.try_extract::()?), - DataType::UInt16 => AnyValue::UInt16(av.try_extract::()?), - DataType::UInt32 => AnyValue::UInt32(av.try_extract::()?), - DataType::UInt64 => AnyValue::UInt64(av.try_extract::()?), - DataType::Int8 => AnyValue::Int8(av.try_extract::()?), - DataType::Int16 => AnyValue::Int16(av.try_extract::()?), - DataType::Int32 => AnyValue::Int32(av.try_extract::()?), - DataType::Int64 => AnyValue::Int64(av.try_extract::()?), - DataType::Float32 => AnyValue::Float32(av.try_extract::()?), - DataType::Float64 => AnyValue::Float64(av.try_extract::()?), - _ => { - polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype) - }, - }; - Ok(out) - } + let new_av = match (self, dtype) { + // to numeric + (av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::()?), + (av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::()?), + (av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::()?), + (av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::()?), + (av, DataType::Int8) => AnyValue::Int8(av.try_extract::()?), + (av, DataType::Int16) => AnyValue::Int16(av.try_extract::()?), + (av, DataType::Int32) => AnyValue::Int32(av.try_extract::()?), + (av, DataType::Int64) => AnyValue::Int64(av.try_extract::()?), + (av, DataType::Float32) => AnyValue::Float32(av.try_extract::()?), + (av, DataType::Float64) => AnyValue::Float64(av.try_extract::()?), - fn cast_to_boolean<'a>(av: &AnyValue) -> PolarsResult> { - let out = match av { - AnyValue::UInt8(v) => AnyValue::Boolean(*v != u8::default()), - AnyValue::UInt16(v) => AnyValue::Boolean(*v != u16::default()), - AnyValue::UInt32(v) => AnyValue::Boolean(*v != u32::default()), - AnyValue::UInt64(v) => AnyValue::Boolean(*v != u64::default()), - AnyValue::Int8(v) => AnyValue::Boolean(*v != i8::default()), - AnyValue::Int16(v) => AnyValue::Boolean(*v != i16::default()), - AnyValue::Int32(v) => AnyValue::Boolean(*v != i32::default()), - AnyValue::Int64(v) => AnyValue::Boolean(*v != i64::default()), - AnyValue::Float32(v) => AnyValue::Boolean(*v != f32::default()), - AnyValue::Float64(v) => AnyValue::Boolean(*v != f64::default()), - _ => { - polars_bail!(ComputeError: "cannot cast any-value {:?} to boolean", av) - }, - }; - Ok(out) - } + // to boolean + (AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()), + (AnyValue::UInt16(v), DataType::Boolean) => AnyValue::Boolean(*v != u16::default()), + (AnyValue::UInt32(v), DataType::Boolean) => AnyValue::Boolean(*v != u32::default()), + (AnyValue::UInt64(v), DataType::Boolean) => AnyValue::Boolean(*v != u64::default()), + (AnyValue::Int8(v), DataType::Boolean) => AnyValue::Boolean(*v != i8::default()), + (AnyValue::Int16(v), DataType::Boolean) => AnyValue::Boolean(*v != i16::default()), + (AnyValue::Int32(v), DataType::Boolean) => AnyValue::Boolean(*v != i32::default()), + (AnyValue::Int64(v), DataType::Boolean) => AnyValue::Boolean(*v != i64::default()), + (AnyValue::Float32(v), DataType::Boolean) => AnyValue::Boolean(*v != f32::default()), + (AnyValue::Float64(v), DataType::Boolean) => AnyValue::Boolean(*v != f64::default()), - let new_av = match self { - _ if (self.is_boolean() | self.is_numeric()) => match dtype { - #[cfg(feature = "dtype-date")] - DataType::Date => AnyValue::Date(self.try_extract::()?), - #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => { - AnyValue::Datetime(self.try_extract::()?, *tu, tz) - }, - #[cfg(feature = "dtype-duration")] - DataType::Duration(tu) => AnyValue::Duration(self.try_extract::()?, *tu), - #[cfg(feature = "dtype-time")] - DataType::Time => AnyValue::Time(self.try_extract::()?), - DataType::String => { - AnyValue::StringOwned(format_smartstring!("{}", self.try_extract::()?)) - }, - DataType::Boolean => return cast_to_boolean(self), - _ => return cast_to_numeric(self, dtype), + // to string + (av, DataType::String) => { + AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::()?)) }, + + // to binary + (AnyValue::String(v), DataType::Binary) => AnyValue::Binary(v.as_bytes()), + + // to datetime #[cfg(feature = "dtype-datetime")] - AnyValue::Datetime(v, tu, None) => match dtype { - #[cfg(feature = "dtype-date")] - // Datetime to Date - DataType::Date => { - let convert = match tu { - TimeUnit::Nanoseconds => timestamp_ns_to_datetime, - TimeUnit::Microseconds => timestamp_us_to_datetime, - TimeUnit::Milliseconds => timestamp_ms_to_datetime, - }; - let ndt = convert(*v); - let date_value = naive_datetime_to_date(ndt); - AnyValue::Date(date_value) + (av, DataType::Datetime(tu, tz)) if av.is_numeric() => { + AnyValue::Datetime(av.try_extract::()?, *tu, tz) + }, + #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))] + (AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime( + match tu { + TimeUnit::Nanoseconds => (*v as i64) * NS_IN_DAY, + TimeUnit::Microseconds => (*v as i64) * US_IN_DAY, + TimeUnit::Milliseconds => (*v as i64) * MS_IN_DAY, }, - #[cfg(feature = "dtype-time")] - // Datetime to Time - DataType::Time => { - let ns_since_midnight = match tu { - TimeUnit::Nanoseconds => *v % NS_IN_DAY, - TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64, - TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64, - }; - AnyValue::Time(ns_since_midnight) + *tu, + &None, + ), + #[cfg(feature = "dtype-datetime")] + (AnyValue::Datetime(v, tu, _), DataType::Datetime(tu_r, tz_r)) => AnyValue::Datetime( + match (tu, tu_r) { + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64, + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64, + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64, + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64, + _ => *v, }, - _ => return cast_to_numeric(self, dtype), - }, + *tu_r, + tz_r, + ), + + // to date + #[cfg(feature = "dtype-date")] + (av, DataType::Date) if av.is_numeric() => AnyValue::Date(av.try_extract::()?), + #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))] + (AnyValue::Datetime(v, tu, _), DataType::Date) => AnyValue::Date(match tu { + TimeUnit::Nanoseconds => *v / NS_IN_DAY, + TimeUnit::Microseconds => *v / US_IN_DAY, + TimeUnit::Milliseconds => *v / MS_IN_DAY, + } as i32), + + // to time + #[cfg(feature = "dtype-time")] + (av, DataType::Time) if av.is_numeric() => AnyValue::Time(av.try_extract::()?), + #[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))] + (AnyValue::Datetime(v, tu, _), DataType::Time) => AnyValue::Time(match tu { + TimeUnit::Nanoseconds => *v % NS_IN_DAY, + TimeUnit::Microseconds => (*v % US_IN_DAY) * 1_000i64, + TimeUnit::Milliseconds => (*v % MS_IN_DAY) * 1_000_000i64, + }), + + // to duration #[cfg(feature = "dtype-duration")] - AnyValue::Duration(v, _) => match dtype { - DataType::Time | DataType::Date | DataType::Datetime(_, _) => { - polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", v, dtype) - }, - _ => return cast_to_numeric(self, dtype), + (av, DataType::Duration(tu)) if av.is_numeric() => { + AnyValue::Duration(av.try_extract::()?, *tu) }, - #[cfg(feature = "dtype-time")] - AnyValue::Time(v) => match dtype { - #[cfg(feature = "dtype-duration")] - // Time to Duration - DataType::Duration(tu) => { - let duration_value = match tu { - TimeUnit::Nanoseconds => *v, - TimeUnit::Microseconds => *v / 1_000i64, - TimeUnit::Milliseconds => *v / 1_000_000i64, - }; - AnyValue::Duration(duration_value, *tu) + #[cfg(all(feature = "dtype-duration", feature = "dtype-time"))] + (AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration( + match *tu { + TimeUnit::Nanoseconds => *v, + TimeUnit::Microseconds => *v / 1_000i64, + TimeUnit::Milliseconds => *v / 1_000_000i64, }, - _ => return cast_to_numeric(self, dtype), - }, - #[cfg(feature = "dtype-date")] - AnyValue::Date(v) => match dtype { - #[cfg(feature = "dtype-datetime")] - // Date to Datetime - DataType::Datetime(tu, None) => { - let ndt = arrow::temporal_conversions::date32_to_datetime(*v); - let func = match tu { - TimeUnit::Nanoseconds => datetime_to_timestamp_ns, - TimeUnit::Microseconds => datetime_to_timestamp_us, - TimeUnit::Milliseconds => datetime_to_timestamp_ms, - }; - let value = func(ndt); - AnyValue::Datetime(value, *tu, &None) + *tu, + ), + #[cfg(feature = "dtype-duration")] + (AnyValue::Duration(v, tu), DataType::Duration(tu_r)) => AnyValue::Duration( + match (tu, tu_r) { + (_, _) if tu == tu_r => *v, + (TimeUnit::Nanoseconds, TimeUnit::Microseconds) => *v / 1_000i64, + (TimeUnit::Nanoseconds, TimeUnit::Milliseconds) => *v / 1_000_000i64, + (TimeUnit::Microseconds, TimeUnit::Nanoseconds) => *v * 1_000i64, + (TimeUnit::Microseconds, TimeUnit::Milliseconds) => *v / 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Microseconds) => *v * 1_000i64, + (TimeUnit::Milliseconds, TimeUnit::Nanoseconds) => *v * 1_000_000i64, + _ => *v, }, - _ => return cast_to_numeric(self, dtype), - }, - AnyValue::String(s) if dtype == &DataType::Binary => AnyValue::Binary(s.as_bytes()), - _ => { - polars_bail!(ComputeError: "cannot cast any-value '{:?}' to '{:?}'", self.dtype(), dtype) - }, + *tu_r, + ), + + // to self + (av, dtype) if av.dtype() == *dtype => self.clone(), + + av => polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype), }; Ok(new_av) } @@ -893,8 +886,8 @@ impl AnyValue<'_> { (Int16(l), Int16(r)) => *l == *r, (Int32(l), Int32(r)) => *l == *r, (Int64(l), Int64(r)) => *l == *r, - (Float32(l), Float32(r)) => *l == *r, - (Float64(l), Float64(r)) => *l == *r, + (Float32(l), Float32(r)) => l.to_total_ord() == r.to_total_ord(), + (Float64(l), Float64(r)) => l.to_total_ord() == r.to_total_ord(), (String(l), String(r)) => l == r, (String(l), StringOwned(r)) => l == r, (StringOwned(l), String(r)) => l == r, @@ -978,8 +971,8 @@ impl PartialOrd for AnyValue<'_> { (Int16(l), Int16(r)) => l.partial_cmp(r), (Int32(l), Int32(r)) => l.partial_cmp(r), (Int64(l), Int64(r)) => l.partial_cmp(r), - (Float32(l), Float32(r)) => l.partial_cmp(r), - (Float64(l), Float64(r)) => l.partial_cmp(r), + (Float32(l), Float32(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()), + (Float64(l), Float64(r)) => l.to_total_ord().partial_cmp(&r.to_total_ord()), (String(l), String(r)) => l.partial_cmp(*r), (Binary(l), Binary(r)) => l.partial_cmp(*r), _ => None, diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 6350a8fc9a45f..7e5947301db47 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -1,6 +1,4 @@ use std::collections::BTreeMap; -use std::convert::Into; -use std::string::ToString; use super::*; #[cfg(feature = "object")] @@ -82,7 +80,14 @@ impl PartialEq for DataType { match (self, other) { // Don't include rev maps in comparisons #[cfg(feature = "dtype-categorical")] - (Categorical(_, _), Categorical(_, _)) | (Enum(_, _), Enum(_, _)) => true, + (Categorical(_, _), Categorical(_, _)) => true, + #[cfg(feature = "dtype-categorical")] + // None means select all Enum dtypes. This is for operation `pl.col(pl.Enum)` + (Enum(None, _), Enum(_, _)) | (Enum(_, _), Enum(None, _)) => true, + #[cfg(feature = "dtype-categorical")] + (Enum(Some(cat_lhs), _), Enum(Some(cat_rhs), _)) => { + cat_lhs.get_categories() == cat_rhs.get_categories() + }, (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) => tu_l == tu_r && tz_l == tz_r, (List(left_inner), List(right_inner)) => left_inner == right_inner, #[cfg(feature = "dtype-duration")] @@ -211,6 +216,27 @@ impl DataType { matches!(self, DataType::Boolean) } + /// Check if this [`DataType`] is a list + pub fn is_list(&self) -> bool { + matches!(self, DataType::List(_)) + } + + pub fn is_nested(&self) -> bool { + self.is_list() || self.is_struct() + } + + /// Check if this [`DataType`] is a struct + pub fn is_struct(&self) -> bool { + #[cfg(feature = "dtype-struct")] + { + matches!(self, DataType::Struct(_)) + } + #[cfg(not(feature = "dtype-struct"))] + { + false + } + } + pub fn is_binary(&self) -> bool { matches!(self, DataType::Binary) } @@ -239,6 +265,7 @@ impl DataType { let phys = self.to_physical(); (phys.is_numeric() + || self.is_decimal() || matches!( phys, DataType::Binary | DataType::String | DataType::Boolean diff --git a/crates/polars-core/src/fmt.rs b/crates/polars-core/src/fmt.rs index 148f70ac854a9..1a1ab80f342ac 100644 --- a/crates/polars-core/src/fmt.rs +++ b/crates/polars-core/src/fmt.rs @@ -25,7 +25,10 @@ use num_traits::{Num, NumCast}; use crate::config::*; use crate::prelude::*; -const LIMIT: usize = 25; + +// Note: see https://github.com/pola-rs/polars/pull/13699 for the rationale +// behind choosing 10 as the default value for default number of rows displayed +const LIMIT: usize = 10; #[derive(Copy, Clone)] #[repr(u8)] @@ -130,19 +133,18 @@ macro_rules! format_array { }; Ok(()) }; - if (limit == 0 && $a.len() > 0) || ($a.len() > limit + 1) { - if limit > 0 { - for i in 0..std::cmp::max((limit / 2), 1) { - let v = $a.get_any_value(i).unwrap(); - write_fn(v, $f)?; - } + if $a.len() > limit { + let half = limit / 2; + let rest = limit % 2; + + for i in 0..(half + rest) { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; } write!($f, "\t…\n")?; - if limit > 1 { - for i in ($a.len() - (limit + 1) / 2)..$a.len() { - let v = $a.get_any_value(i).unwrap(); - write_fn(v, $f)?; - } + for i in ($a.len() - half)..$a.len() { + let v = $a.get_any_value(i).unwrap(); + write_fn(v, $f)?; } } else { for i in 0..$a.len() { @@ -524,9 +526,7 @@ impl Display for DataFrame { .as_deref() .unwrap_or("") .parse() - // Note: see "https://github.com/pola-rs/polars/pull/13699" for - // the rationale behind choosing 10 as the default value ;) - .map_or(10, |n: i64| if n < 0 { height } else { n as usize }); + .map_or(LIMIT, |n: i64| if n < 0 { height } else { n as usize }); let (n_first, n_last) = if self.width() > max_n_cols { ((max_n_cols + 1) / 2, max_n_cols / 2) @@ -588,11 +588,15 @@ impl Display for DataFrame { let mut max_elem_lengths: Vec = vec![0; n_tbl_cols]; if max_n_rows > 0 { - if height > max_n_rows + 1 { - // Truncate the table if we have more rows than the configured maximum - // number of rows plus the single row which would contain "…". + if height > max_n_rows { + // Truncate the table if we have more rows than the + // configured maximum number of rows let mut rows = Vec::with_capacity(std::cmp::max(max_n_rows, 2)); - for i in 0..std::cmp::max(max_n_rows / 2, 1) { + + let half = max_n_rows / 2; + let rest = max_n_rows % 2; + + for i in 0..(half + rest) { let row = self .columns .iter() @@ -606,23 +610,16 @@ impl Display for DataFrame { } let dots = rows[0].iter().map(|_| "…".to_string()).collect(); rows.push(dots); - if max_n_rows > 1 { - for i in (height - (max_n_rows + 1) / 2)..height { - let row = self - .columns - .iter() - .map(|s| s.str_value(i).unwrap()) - .collect(); + for i in (height - half)..height { + let row = self + .columns + .iter() + .map(|s| s.str_value(i).unwrap()) + .collect(); - let row_strings = prepare_row( - row, - n_first, - n_last, - str_truncate, - &mut max_elem_lengths, - ); - rows.push(row_strings); - } + let row_strings = + prepare_row(row, n_first, n_last, str_truncate, &mut max_elem_lengths); + rows.push(row_strings); } table.add_rows(rows); } else { diff --git a/crates/polars-core/src/frame/arithmetic.rs b/crates/polars-core/src/frame/arithmetic.rs index be60fb04346ff..0082ecc4534a6 100644 --- a/crates/polars-core/src/frame/arithmetic.rs +++ b/crates/polars-core/src/frame/arithmetic.rs @@ -21,7 +21,7 @@ macro_rules! impl_arithmetic { let cols = POOL.install(|| {$self.columns.par_iter().map(|s| { Ok(&s.cast(&st)? $operand &rhs) }).collect::>()})?; - Ok(DataFrame::new_no_checks(cols)) + Ok(unsafe { DataFrame::new_no_checks(cols) }) }} } @@ -113,7 +113,7 @@ impl DataFrame { ) -> PolarsResult { let max_len = std::cmp::max(self.height(), other.height()); let max_width = std::cmp::max(self.width(), other.width()); - let mut cols = self + let cols = self .get_columns() .par_iter() .zip(other.get_columns().par_iter()) @@ -133,8 +133,8 @@ impl DataFrame { }; f(&l, &r) - }) - .collect::>>()?; + }); + let mut cols = POOL.install(|| cols.collect::>>())?; let col_len = cols.len(); if col_len < max_width { diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs index 66c9e2476dad5..51fe294aa3fa4 100644 --- a/crates/polars-core/src/frame/explode.rs +++ b/crates/polars-core/src/frame/explode.rs @@ -275,7 +275,7 @@ impl DataFrame { out.push(variable_col); out.push(value_col); - return Ok(DataFrame::new_no_checks(out)); + return Ok(unsafe { DataFrame::new_no_checks(out) }); } let id_vars_set = PlHashSet::from_iter(id_vars.iter().map(|s| s.as_str())); @@ -354,7 +354,6 @@ impl DataFrame { #[cfg(test)] mod test { - use crate::frame::explode::MeltArgs; use crate::prelude::*; #[test] diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs index 4845edb79a161..72172ec7e7364 100644 --- a/crates/polars-core/src/frame/from.rs +++ b/crates/polars-core/src/frame/from.rs @@ -1,5 +1,3 @@ -use arrow::array::StructArray; - use crate::prelude::*; impl TryFrom for DataFrame { @@ -37,7 +35,7 @@ impl From<&Schema> for DataFrame { .iter() .map(|(name, dtype)| Series::new_empty(name, dtype)) .collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } @@ -48,6 +46,6 @@ impl From<&ArrowSchema> for DataFrame { .iter() .map(|fld| Series::new_empty(fld.name.as_str(), &(fld.data_type().into()))) .collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 0addcb2b56d30..46dc3261d6801 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -12,7 +12,6 @@ use arrow::legacy::kernels::rolling::no_nulls::{ MaxWindow, MeanWindow, MinWindow, QuantileWindow, RollingAggWindowNoNulls, SumWindow, VarWindow, }; use arrow::legacy::kernels::rolling::nulls::RollingAggWindowNulls; -use arrow::legacy::kernels::rolling::{RollingQuantileParams, RollingVarParams}; use arrow::legacy::kernels::take_agg::*; use arrow::legacy::prelude::QuantileInterpolOptions; use arrow::legacy::trusted_len::TrustedLenPush; @@ -52,7 +51,10 @@ pub fn _use_rolling_kernels(groups: &GroupsSlice, chunks: &[ArrayRef]) -> bool { let [first_offset, first_len] = groups[0]; let second_offset = groups[1][0]; - second_offset < (first_offset + first_len) && chunks.len() == 1 + second_offset >= first_offset // Prevent false positive from regular group-by that has out of order slices. + // Rolling group-by is expected to have monotonically increasing slices. + && second_offset < (first_offset + first_len) + && chunks.len() == 1 }, } } @@ -140,7 +142,7 @@ where None } else { // SAFETY: we are in bounds. - Some(unsafe { agg_window.update(start as usize, end as usize) }) + unsafe { agg_window.update(start as usize, end as usize) } } }) .collect::>() @@ -797,7 +799,13 @@ where debug_assert!(len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap()) @@ -859,7 +867,13 @@ where debug_assert!(len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap()) @@ -1010,7 +1024,13 @@ where debug_assert!(first + len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.var(ddof) @@ -1052,7 +1072,13 @@ where debug_assert!(first + len <= self.len() as IdxSize); match len { 0 => None, - 1 => NumCast::from(0), + 1 => { + if ddof == 0 { + NumCast::from(0) + } else { + None + } + }, _ => { let arr_group = _slice_from_offsets(self, first, len); arr_group.std(ddof) diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs index 796b5c2f33d1d..b3e85e5dacb5b 100644 --- a/crates/polars-core/src/frame/group_by/hashing.rs +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -6,18 +6,14 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; use polars_utils::iter::EnumerateIdxTrait; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use polars_utils::unitvec; use rayon::prelude::*; -use super::GroupsProxy; -use crate::datatypes::PlHashMap; -use crate::frame::group_by::{GroupsIdx, IdxItem}; -use crate::hashing::{ - _df_rows_to_hashes_threaded_vertical, series_to_hashes, IdBuildHasher, IdxHash, *, -}; +use crate::hashing::*; use crate::prelude::compare_inner::TotalEqInner; use crate::prelude::*; -use crate::utils::{flatten, split_df, CustomIterTools}; +use crate::utils::{flatten, split_df}; use crate::POOL; fn get_init_size() -> usize { @@ -144,12 +140,15 @@ fn finish_group_order_vecs( pub(crate) fn group_by(a: impl Iterator, sorted: bool) -> GroupsProxy where - T: Hash + Eq, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let init_size = get_init_size(); - let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); + let mut hash_tbl: PlHashMap = + PlHashMap::with_capacity(init_size); let mut cnt = 0; a.for_each(|k| { + let k = k.to_total_ord(); let idx = cnt; cnt += 1; let entry = hash_tbl.entry(k); @@ -188,7 +187,8 @@ pub(crate) fn group_by_threaded_slice( sorted: bool, ) -> GroupsProxy where - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash, IntoSlice: AsRef<[T]> + Send + Sync, { let init_size = get_init_size(); @@ -200,7 +200,7 @@ where (0..n_partitions) .into_par_iter() .map(|thread_no| { - let mut hash_tbl: PlHashMap = + let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); let mut offset = 0; @@ -211,18 +211,19 @@ where let mut cnt = 0; keys.iter().for_each(|k| { + let k = k.to_total_ord(); let idx = cnt + offset; cnt += 1; if thread_no == hash_to_partition(k.dirty_hash(), n_partitions) { let hash = hasher.hash_one(k); - let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, k); + let entry = hash_tbl.raw_entry_mut().from_key_hashed_nocheck(hash, &k); match entry { RawEntryMut::Vacant(entry) => { let tuples = unitvec![idx]; - entry.insert_with_hasher(hash, *k, (idx, tuples), |k| { - hasher.hash_one(k) + entry.insert_with_hasher(hash, k, (idx, tuples), |k| { + hasher.hash_one(*k) }); }, RawEntryMut::Occupied(mut entry) => { @@ -252,7 +253,8 @@ pub(crate) fn group_by_threaded_iter( where I: IntoIterator + Send + Sync + Clone, I::IntoIter: ExactSizeIterator, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Hash + Eq + Sync + Copy + DirtyHash, { let init_size = get_init_size(); @@ -263,7 +265,7 @@ where (0..n_partitions) .into_par_iter() .map(|thread_no| { - let mut hash_tbl: PlHashMap = + let mut hash_tbl: PlHashMap = PlHashMap::with_capacity(init_size); let mut offset = 0; @@ -274,6 +276,7 @@ where let mut cnt = 0; keys.for_each(|k| { + let k = k.to_total_ord(); let idx = cnt + offset; cnt += 1; @@ -285,7 +288,7 @@ where RawEntryMut::Vacant(entry) => { let tuples = unitvec![idx]; entry.insert_with_hasher(hash, k, (idx, tuples), |k| { - hasher.hash_one(k) + hasher.hash_one(*k) }); }, RawEntryMut::Occupied(mut entry) => { diff --git a/crates/polars-core/src/frame/group_by/into_groups.rs b/crates/polars-core/src/frame/group_by/into_groups.rs index 529b9776e8437..fe2fd5a493e57 100644 --- a/crates/polars-core/src/frame/group_by/into_groups.rs +++ b/crates/polars-core/src/frame/group_by/into_groups.rs @@ -1,7 +1,7 @@ #[cfg(feature = "group_by_list")] use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter; use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups}; -use arrow::legacy::prelude::*; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; use crate::config::verbose; @@ -12,7 +12,7 @@ use crate::utils::flatten::flatten_par; pub trait IntoGroupsProxy { /// Create the tuples need for a group_by operation. /// * The first value in the tuple is the first index of the group. - /// * The second value in the tuple is are the indexes of the groups including the first value. + /// * The second value in the tuple is the indexes of the groups including the first value. fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { unimplemented!() } @@ -25,9 +25,9 @@ fn group_multithreaded(ca: &ChunkedArray) -> bool { fn num_groups_proxy(ca: &ChunkedArray, multithreaded: bool, sorted: bool) -> GroupsProxy where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + DirtyHash, - Option: DirtyHash, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash, { if multithreaded && group_multithreaded(ca) { let n_partitions = _set_partition_size(); @@ -93,35 +93,31 @@ where let n_parts = parts.len(); let first_ptr = &values[0] as *const T::Native as usize; - let groups = POOL - .install(|| { - parts.par_iter().enumerate().map(|(i, part)| { - // we go via usize as *const is not send - let first_ptr = first_ptr as *const T::Native; - - let part_first_ptr = &part[0] as *const T::Native; - let mut offset = - unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize; - - // nulls first: only add the nulls at the first partition - if nulls_first && i == 0 { - partition_to_groups(part, null_count as IdxSize, true, offset) - } - // nulls last: only compute at the last partition - else if !nulls_first && i == n_parts - 1 { - partition_to_groups(part, null_count as IdxSize, false, offset) - } - // other partitions - else { - if nulls_first { - offset += null_count as IdxSize; - }; + let groups = parts.par_iter().enumerate().map(|(i, part)| { + // we go via usize as *const is not send + let first_ptr = first_ptr as *const T::Native; - partition_to_groups(part, 0, false, offset) - } - }) - }) - .collect::>(); + let part_first_ptr = &part[0] as *const T::Native; + let mut offset = unsafe { part_first_ptr.offset_from(first_ptr) } as IdxSize; + + // nulls first: only add the nulls at the first partition + if nulls_first && i == 0 { + partition_to_groups(part, null_count as IdxSize, true, offset) + } + // nulls last: only compute at the last partition + else if !nulls_first && i == n_parts - 1 { + partition_to_groups(part, null_count as IdxSize, false, offset) + } + // other partitions + else { + if nulls_first { + offset += null_count as IdxSize; + }; + + partition_to_groups(part, 0, false, offset) + } + }); + let groups = POOL.install(|| groups.collect::>()); flatten_par(&groups) } else { partition_to_groups(values, null_count as IdxSize, nulls_first, 0) @@ -167,14 +163,36 @@ where }; num_groups_proxy(ca, multithreaded, sorted) }, - DataType::Int64 | DataType::Float64 => { + DataType::Int64 => { let ca = self.bit_repr_large(); num_groups_proxy(&ca, multithreaded, sorted) }, - DataType::Int32 | DataType::Float32 => { + DataType::Int32 => { let ca = self.bit_repr_small(); num_groups_proxy(&ca, multithreaded, sorted) }, + DataType::Float64 => { + // convince the compiler that we are this type. + let ca: &Float64Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + DataType::Float32 => { + // convince the compiler that we are this type. + let ca: &Float32Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + // convince the compiler that we are this type. + let ca: &Int128Chunked = unsafe { + &*(self as *const ChunkedArray as *const ChunkedArray) + }; + num_groups_proxy(ca, multithreaded, sorted) + }, #[cfg(all(feature = "performant", feature = "dtype-i8", feature = "dtype-u8"))] DataType::Int8 => { // convince the compiler that we are this type. diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 339432b07edd0..75df1e198c50c 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -2,7 +2,6 @@ use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; use ahash::RandomState; -use arrow::legacy::prelude::QuantileInterpolOptions; use num_traits::NumCast; use polars_utils::hashing::{BytesHash, DirtyHash}; use rayon::prelude::*; @@ -28,28 +27,31 @@ use crate::prelude::sort::arg_sort_multiple::encode_rows_vertical; // This will remove the sorted flag on signed integers fn prepare_dataframe_unsorted(by: &[Series]) -> DataFrame { - DataFrame::new_no_checks( - by.iter() - .map(|s| match s.dtype() { - #[cfg(feature = "dtype-categorical")] - DataType::Categorical(_, _) | DataType::Enum(_, _) => { - s.cast(&DataType::UInt32).unwrap() - }, - _ => { - if s.dtype().to_physical().is_numeric() { - let s = s.to_physical_repr(); - if s.bit_repr_is_large() { - s.bit_repr_large().into_series() - } else { - s.bit_repr_small().into_series() - } + let columns = by + .iter() + .map(|s| match s.dtype() { + #[cfg(feature = "dtype-categorical")] + DataType::Categorical(_, _) | DataType::Enum(_, _) => { + s.cast(&DataType::UInt32).unwrap() + }, + _ => { + if s.dtype().to_physical().is_numeric() { + let s = s.to_physical_repr(); + + if s.dtype().is_float() { + s.into_owned().into_series() + } else if s.bit_repr_is_large() { + s.bit_repr_large().into_series() } else { - s.clone() + s.bit_repr_small().into_series() } - }, - }) - .collect(), - ) + } else { + s.clone() + } + }, + }) + .collect(); + unsafe { DataFrame::new_no_checks(columns) } } impl DataFrame { @@ -793,7 +795,7 @@ impl<'df> GroupBy<'df> { new_cols.extend_from_slice(&self.selected_keys); let cols = self.df.select_series(agg)?; new_cols.extend(cols); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) } } else { Ok(self.df.clone()) diff --git a/crates/polars-core/src/frame/group_by/perfect.rs b/crates/polars-core/src/frame/group_by/perfect.rs index 9853b52d5a93e..f120192404116 100644 --- a/crates/polars-core/src/frame/group_by/perfect.rs +++ b/crates/polars-core/src/frame/group_by/perfect.rs @@ -1,12 +1,10 @@ use std::fmt::Debug; -use arrow::array::Array; use arrow::legacy::bit_util::round_upto_multiple_of_64; use num_traits::{FromPrimitive, ToPrimitive}; use polars_utils::idx_vec::IdxVec; use polars_utils::slice::GetSaferUnchecked; use polars_utils::sync::SyncPtr; -use polars_utils::IdxSize; use rayon::prelude::*; #[cfg(all(feature = "dtype-categorical", feature = "performant"))] @@ -198,7 +196,7 @@ impl CategoricalChunked { let mut out = match &**rev_map { RevMapping::Local(cached, _) => { - if self.can_fast_unique() { + if self._can_fast_unique() { if verbose() { eprintln!("grouping categoricals, run perfect hash function"); } diff --git a/crates/polars-core/src/frame/group_by/proxy.rs b/crates/polars-core/src/frame/group_by/proxy.rs index 5615ed9f2a73e..e1988c363712b 100644 --- a/crates/polars-core/src/frame/group_by/proxy.rs +++ b/crates/polars-core/src/frame/group_by/proxy.rs @@ -1,7 +1,6 @@ use std::mem::ManuallyDrop; use std::ops::Deref; -use arrow::legacy::utils::CustomIterTools; use polars_utils::idx_vec::IdxVec; use polars_utils::sync::SyncPtr; use rayon::iter::plumbing::UnindexedConsumer; diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index ebd137d9c6e2f..1f6001ca3fd31 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1,6 +1,5 @@ //! DataFrame module. use std::borrow::Cow; -use std::iter::{FromIterator, Iterator}; use std::{mem, ops}; use ahash::AHashSet; @@ -28,8 +27,6 @@ pub use chunks::*; use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; -#[cfg(feature = "algorithm_group_by")] -use crate::frame::group_by::GroupsIndicator; #[cfg(feature = "row_hash")] use crate::hashing::_df_rows_to_hashes_threaded_vertical; #[cfg(feature = "zip_with")] @@ -313,7 +310,8 @@ impl DataFrame { /// static EMPTY: DataFrame = DataFrame::empty(); /// ``` pub const fn empty() -> Self { - DataFrame::new_no_checks(Vec::new()) + // SAFETY: An empty dataframe cannot have length mismatches or duplicate names + unsafe { DataFrame::new_no_checks(Vec::new()) } } /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. @@ -400,24 +398,26 @@ impl DataFrame { /// Create a new `DataFrame` but does not check the length or duplicate occurrence of the `Series`. /// - /// It is advised to use [DataFrame::new](DataFrame::new) in favor of this method. + /// It is advised to use [DataFrame::new] in favor of this method. + /// + /// # Safety /// - /// # Panic /// It is the callers responsibility to uphold the contract of all `Series` - /// having an equal length, if not this may panic down the line. - pub const fn new_no_checks(columns: Vec) -> DataFrame { + /// having an equal length and a unique name, if not this may panic down the line. + pub const unsafe fn new_no_checks(columns: Vec) -> DataFrame { DataFrame { columns } } /// Create a new `DataFrame` but does not check the length of the `Series`, /// only check for duplicates. /// - /// It is advised to use [DataFrame::new](DataFrame::new) in favor of this method. + /// It is advised to use [DataFrame::new] in favor of this method. + /// + /// # Safety /// - /// # Panic /// It is the callers responsibility to uphold the contract of all `Series` /// having an equal length, if not this may panic down the line. - pub fn new_no_length_checks(columns: Vec) -> PolarsResult { + pub unsafe fn new_no_length_checks(columns: Vec) -> PolarsResult { let mut names = PlHashSet::with_capacity(columns.len()); for column in &columns { let name = column.name(); @@ -437,7 +437,7 @@ impl DataFrame { // Don't parallelize this. Memory overhead let f = |s: &Series| s.rechunk(); let cols = self.columns.iter().map(f).collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// Shrink the capacity of this DataFrame to fit its length. @@ -545,6 +545,7 @@ impl DataFrame { #[inline] /// Get mutable access to the underlying columns. + /// /// # Safety /// The caller must ensure the length of all [`Series`] remains equal. pub unsafe fn get_columns_mut(&mut self) -> &mut Vec { @@ -932,7 +933,7 @@ impl DataFrame { "unable to append to a DataFrame of width {} with a DataFrame of width {}", self.width(), other.width(), ); - self.columns = other.columns.clone(); + self.columns.clone_from(&other.columns); return Ok(self); } @@ -1088,7 +1089,7 @@ impl DataFrame { } }); - Ok(DataFrame::new_no_checks(new_cols)) + Ok(unsafe { DataFrame::new_no_checks(new_cols) }) } /// Drop columns that are in `names`. @@ -1106,7 +1107,7 @@ impl DataFrame { } }); - DataFrame::new_no_checks(new_cols) + unsafe { DataFrame::new_no_checks(new_cols) } } /// Insert a new column at a given index without checking for duplicates. @@ -1454,7 +1455,7 @@ impl DataFrame { pub fn _select_impl_unchecked(&self, cols: &[SmartString]) -> PolarsResult { let selected = self.select_series_impl(cols)?; - Ok(DataFrame::new_no_checks(selected)) + Ok(unsafe { DataFrame::new_no_checks(selected) }) } /// Select with a known schema. @@ -1497,7 +1498,7 @@ impl DataFrame { self.select_check_duplicates(cols)?; } let selected = self.select_series_impl_with_schema(cols, schema)?; - Ok(DataFrame::new_no_checks(selected)) + Ok(unsafe { DataFrame::new_no_checks(selected) }) } /// A non generic implementation to reduce compiler bloat. @@ -1529,7 +1530,7 @@ impl DataFrame { fn select_physical_impl(&self, cols: &[SmartString]) -> PolarsResult { self.select_check_duplicates(cols)?; let selected = self.select_series_physical_impl(cols)?; - Ok(DataFrame::new_no_checks(selected)) + Ok(unsafe { DataFrame::new_no_checks(selected) }) } fn select_check_duplicates(&self, cols: &[SmartString]) -> PolarsResult<()> { @@ -1645,7 +1646,7 @@ impl DataFrame { .iter() .map(|s| s.filter(mask)) .collect::>()?; - Ok(DataFrame::new_no_checks(cols)) + Ok(unsafe { DataFrame::new_no_checks(cols) }) }) .collect() }); @@ -1674,13 +1675,13 @@ impl DataFrame { return self.clone().filter_vertical(mask); } let new_col = self.try_apply_columns_par(&|s| s.filter(mask))?; - Ok(DataFrame::new_no_checks(new_col)) + Ok(unsafe { DataFrame::new_no_checks(new_col) }) } /// Same as `filter` but does not parallelize. pub fn _filter_seq(&self, mask: &BooleanChunked) -> PolarsResult { let new_col = self.try_apply_columns(&|s| s.filter(mask))?; - Ok(DataFrame::new_no_checks(new_col)) + Ok(unsafe { DataFrame::new_no_checks(new_col) }) } /// Take [`DataFrame`] rows by index values. @@ -1697,7 +1698,7 @@ impl DataFrame { pub fn take(&self, indices: &IdxCa) -> PolarsResult { let new_col = POOL.install(|| self.try_apply_columns_par(&|s| s.take(indices)))?; - Ok(DataFrame::new_no_checks(new_col)) + Ok(unsafe { DataFrame::new_no_checks(new_col) }) } /// # Safety @@ -1706,13 +1707,15 @@ impl DataFrame { self.take_unchecked_impl(idx, true) } - unsafe fn take_unchecked_impl(&self, idx: &IdxCa, allow_threads: bool) -> Self { + /// # Safety + /// The indices must be in-bounds. + pub unsafe fn take_unchecked_impl(&self, idx: &IdxCa, allow_threads: bool) -> Self { let cols = if allow_threads { POOL.install(|| self._apply_columns_par(&|s| s.take_unchecked(idx))) } else { self.columns.iter().map(|s| s.take_unchecked(idx)).collect() }; - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } pub(crate) unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { @@ -1728,7 +1731,7 @@ impl DataFrame { .map(|s| s.take_slice_unchecked(idx)) .collect() }; - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// Rename a column in the [`DataFrame`]. @@ -2257,12 +2260,12 @@ impl DataFrame { .iter() .map(|s| s.slice(offset, length)) .collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } pub fn clear(&self) -> Self { let col = self.columns.iter().map(|s| s.clear()).collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } #[must_use] @@ -2270,7 +2273,8 @@ impl DataFrame { if offset == 0 && length == self.height() { return self.clone(); } - DataFrame::new_no_checks(self._apply_columns_par(&|s| s.slice(offset, length))) + let columns = self._apply_columns_par(&|s| s.slice(offset, length)); + unsafe { DataFrame::new_no_checks(columns) } } #[must_use] @@ -2278,11 +2282,12 @@ impl DataFrame { if offset == 0 && length == self.height() { return self.clone(); } - DataFrame::new_no_checks(self._apply_columns(&|s| { + let columns = self._apply_columns(&|s| { let mut out = s.slice(offset, length); out.shrink_to_fit(); out - })) + }); + unsafe { DataFrame::new_no_checks(columns) } } /// Get the head of the [`DataFrame`]. @@ -2325,7 +2330,7 @@ impl DataFrame { .iter() .map(|s| s.head(length)) .collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Get the tail of the [`DataFrame`]. @@ -2365,7 +2370,7 @@ impl DataFrame { .iter() .map(|s| s.tail(length)) .collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Iterator over the rows in this [`DataFrame`] as Arrow RecordBatches. @@ -2405,7 +2410,7 @@ impl DataFrame { #[must_use] pub fn reverse(&self) -> Self { let col = self.columns.iter().map(|s| s.reverse()).collect::>(); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Shift the values by a given period and fill the parts that will be empty due to this operation @@ -2416,7 +2421,7 @@ impl DataFrame { pub fn shift(&self, periods: i64) -> Self { let col = self._apply_columns_par(&|s| s.shift(periods)); - DataFrame::new_no_checks(col) + unsafe { DataFrame::new_no_checks(col) } } /// Replace None values with one of the following strategies: @@ -2430,7 +2435,7 @@ impl DataFrame { pub fn fill_null(&self, strategy: FillNullStrategy) -> PolarsResult { let col = self.try_apply_columns_par(&|s| s.fill_null(strategy))?; - Ok(DataFrame::new_no_checks(col)) + Ok(unsafe { DataFrame::new_no_checks(col) }) } /// Aggregate the column horizontally to their min values. @@ -2563,7 +2568,7 @@ impl DataFrame { }) .cloned() .collect(); - let numeric_df = DataFrame::new_no_checks(columns); + let numeric_df = unsafe { DataFrame::new_no_checks(columns) }; let sum = || numeric_df.sum_horizontal(null_strategy); @@ -2745,7 +2750,7 @@ impl DataFrame { return df.filter(&mask); }, }; - Ok(DataFrame::new_no_checks(columns)) + Ok(unsafe { DataFrame::new_no_checks(columns) }) } /// Get a mask of all the unique rows in the [`DataFrame`]. @@ -2806,7 +2811,7 @@ impl DataFrame { .iter() .map(|s| Series::new(s.name(), &[s.null_count() as IdxSize])) .collect(); - Self::new_no_checks(cols) + unsafe { Self::new_no_checks(cols) } } /// Hash and combine the row values @@ -3035,7 +3040,7 @@ impl Iterator for PhysRecordBatchIter<'_> { impl Default for DataFrame { fn default() -> Self { - DataFrame::new_no_checks(vec![]) + DataFrame::empty() } } @@ -3058,7 +3063,6 @@ fn ensure_can_extend(left: &Series, right: &Series) -> PolarsResult<()> { #[cfg(test)] mod test { use super::*; - use crate::frame::NullStrategy; fn create_frame() -> DataFrame { let s0 = Series::new("days", [0, 1, 2].as_ref()); diff --git a/crates/polars-core/src/frame/row/dataframe.rs b/crates/polars-core/src/frame/row/dataframe.rs index 1aa2197d1ac54..266677b14defc 100644 --- a/crates/polars-core/src/frame/row/dataframe.rs +++ b/crates/polars-core/src/frame/row/dataframe.rs @@ -1,5 +1,4 @@ use super::*; -use crate::frame::row::av_buffer::AnyValueBuffer; impl DataFrame { /// Get a row from a [`DataFrame`]. Use of this is discouraged as it will likely be slow. diff --git a/crates/polars-core/src/frame/row/transpose.rs b/crates/polars-core/src/frame/row/transpose.rs index 05b21ddb3a8fc..0fdc15c9c6f6d 100644 --- a/crates/polars-core/src/frame/row/transpose.rs +++ b/crates/polars-core/src/frame/row/transpose.rs @@ -79,7 +79,7 @@ impl DataFrame { })); }, }; - Ok(DataFrame::new_no_checks(cols_t)) + Ok(unsafe { DataFrame::new_no_checks(cols_t) }) } /// Transpose a DataFrame. This is a very expensive operation. @@ -224,37 +224,36 @@ where }) }); - cols_t.par_extend(POOL.install(|| { - values_buf - .into_par_iter() - .zip(validity_buf) - .zip(names_out) - .map(|((mut values, validity), name)| { - // SAFETY: - // all values are written we can now set len - unsafe { - values.set_len(new_height); - } + let par_iter = values_buf + .into_par_iter() + .zip(validity_buf) + .zip(names_out) + .map(|((mut values, validity), name)| { + // SAFETY: + // all values are written we can now set len + unsafe { + values.set_len(new_height); + } - let validity = if has_nulls { - let validity = Bitmap::from_trusted_len_iter(validity.iter().copied()); - if validity.unset_bits() > 0 { - Some(validity) - } else { - None - } + let validity = if has_nulls { + let validity = Bitmap::from_trusted_len_iter(validity.iter().copied()); + if validity.unset_bits() > 0 { + Some(validity) } else { None - }; + } + } else { + None + }; - let arr = PrimitiveArray::::new( - T::get_dtype().to_arrow(true), - values.into(), - validity, - ); - ChunkedArray::with_chunk(name.as_str(), arr).into_series() - }) - })); + let arr = PrimitiveArray::::new( + T::get_dtype().to_arrow(true), + values.into(), + validity, + ); + ChunkedArray::with_chunk(name.as_str(), arr).into_series() + }); + POOL.install(|| cols_t.par_extend(par_iter)); } #[cfg(test)] diff --git a/crates/polars-core/src/frame/top_k.rs b/crates/polars-core/src/frame/top_k.rs index b72116821dc9e..e201d1abb40aa 100644 --- a/crates/polars-core/src/frame/top_k.rs +++ b/crates/polars-core/src/frame/top_k.rs @@ -1,12 +1,8 @@ use std::cmp::Ordering; -use polars_error::PolarsResult; use polars_utils::iter::EnumerateIdxTrait; -use polars_utils::IdxSize; use smartstring::alias::String as SmartString; -use crate::datatypes::IdxCa; -use crate::frame::DataFrame; use crate::prelude::sort::_broadcast_descending; use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded; use crate::prelude::*; diff --git a/crates/polars-core/src/frame/upstream_traits.rs b/crates/polars-core/src/frame/upstream_traits.rs index 21f5a0e74f84d..e2f28aefdb330 100644 --- a/crates/polars-core/src/frame/upstream_traits.rs +++ b/crates/polars-core/src/frame/upstream_traits.rs @@ -1,4 +1,3 @@ -use std::iter::FromIterator; use std::ops::{Index, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; use crate::prelude::*; diff --git a/crates/polars-core/src/functions.rs b/crates/polars-core/src/functions.rs index 6c802e02656c5..06fcea6612198 100644 --- a/crates/polars-core/src/functions.rs +++ b/crates/polars-core/src/functions.rs @@ -75,7 +75,7 @@ pub fn concat_df_diagonal(dfs: &[DataFrame]) -> PolarsResult { None => columns.push(Series::full_null(name, height, dtype)), } } - DataFrame::new_no_checks(columns) + unsafe { DataFrame::new_no_checks(columns) } }) .collect::>(); diff --git a/crates/polars-core/src/hashing/vector_hasher.rs b/crates/polars-core/src/hashing/vector_hasher.rs index 4b882bb2ce5e2..1c7635e701b1d 100644 --- a/crates/polars-core/src/hashing/vector_hasher.rs +++ b/crates/polars-core/src/hashing/vector_hasher.rs @@ -1,14 +1,13 @@ use arrow::bitmap::utils::get_bit_unchecked; #[cfg(feature = "group_by_list")] use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use rayon::prelude::*; use xxhash_rust::xxh3::xxh3_64_with_seed; use super::*; -use crate::datatypes::UInt64Chunked; use crate::prelude::*; use crate::series::implementations::null::NullChunked; -use crate::utils::arrow::array::Array; use crate::POOL; // See: https://github.com/tkaitchuck/aHash/blob/f9acd508bd89e7c5b2877a9510098100f9018d64/src/operations.rs#L4 @@ -67,10 +66,11 @@ fn insert_null_hash(chunks: &[ArrayRef], random_state: RandomState, buf: &mut Ve }); } -fn integer_vec_hash(ca: &ChunkedArray, random_state: RandomState, buf: &mut Vec) +fn numeric_vec_hash(ca: &ChunkedArray, random_state: RandomState, buf: &mut Vec) where - T: PolarsIntegerType, - T::Native: Hash, + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, { // Note that we don't use the no null branch! This can break in unexpected ways. // for instance with threading we split an array in n_threads, this may lead to @@ -89,16 +89,17 @@ where .as_slice() .iter() .copied() - .map(|v| random_state.hash_one(v)), + .map(|v| random_state.hash_one(v.to_total_ord())), ); }); insert_null_hash(&ca.chunks, random_state, buf) } -fn integer_vec_hash_combine(ca: &ChunkedArray, random_state: RandomState, hashes: &mut [u64]) +fn numeric_vec_hash_combine(ca: &ChunkedArray, random_state: RandomState, hashes: &mut [u64]) where - T: PolarsIntegerType, - T::Native: Hash, + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, { let null_h = get_null_hash_value(&random_state); @@ -111,7 +112,7 @@ where .iter() .zip(&mut hashes[offset..]) .for_each(|(v, h)| { - *h = folded_multiply(random_state.hash_one(v) ^ *h, MULTIPLE); + *h = folded_multiply(random_state.hash_one(v.to_total_ord()) ^ *h, MULTIPLE); }), _ => { let validity = arr.validity().unwrap(); @@ -121,7 +122,7 @@ where .zip(&mut hashes[offset..]) .zip(arr.values().as_slice()) .for_each(|((valid, h), l)| { - let lh = random_state.hash_one(l); + let lh = random_state.hash_one(l.to_total_ord()); let to_hash = [null_h, lh][valid as usize]; // inlined from ahash. This ensures we combine with the previous state @@ -133,11 +134,11 @@ where }); } -macro_rules! vec_hash_int { +macro_rules! vec_hash_numeric { ($ca:ident) => { impl VecHash for $ca { fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - integer_vec_hash(self, random_state, buf); + numeric_vec_hash(self, random_state, buf); Ok(()) } @@ -146,21 +147,25 @@ macro_rules! vec_hash_int { random_state: RandomState, hashes: &mut [u64], ) -> PolarsResult<()> { - integer_vec_hash_combine(self, random_state, hashes); + numeric_vec_hash_combine(self, random_state, hashes); Ok(()) } } }; } -vec_hash_int!(Int64Chunked); -vec_hash_int!(Int32Chunked); -vec_hash_int!(Int16Chunked); -vec_hash_int!(Int8Chunked); -vec_hash_int!(UInt64Chunked); -vec_hash_int!(UInt32Chunked); -vec_hash_int!(UInt16Chunked); -vec_hash_int!(UInt8Chunked); +vec_hash_numeric!(Int64Chunked); +vec_hash_numeric!(Int32Chunked); +vec_hash_numeric!(Int16Chunked); +vec_hash_numeric!(Int8Chunked); +vec_hash_numeric!(UInt64Chunked); +vec_hash_numeric!(UInt32Chunked); +vec_hash_numeric!(UInt16Chunked); +vec_hash_numeric!(UInt8Chunked); +vec_hash_numeric!(Float64Chunked); +vec_hash_numeric!(Float32Chunked); +#[cfg(feature = "dtype-decimal")] +vec_hash_numeric!(Int128Chunked); impl VecHash for StringChunked { fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { @@ -370,30 +375,6 @@ impl VecHash for BooleanChunked { } } -impl VecHash for Float32Chunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - self.bit_repr_small().vec_hash(random_state, buf)?; - Ok(()) - } - - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { - self.bit_repr_small() - .vec_hash_combine(random_state, hashes)?; - Ok(()) - } -} -impl VecHash for Float64Chunked { - fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { - self.bit_repr_large().vec_hash(random_state, buf)?; - Ok(()) - } - fn vec_hash_combine(&self, random_state: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { - self.bit_repr_large() - .vec_hash_combine(random_state, hashes)?; - Ok(()) - } -} - #[cfg(feature = "group_by_list")] impl VecHash for ListChunked { fn vec_hash(&self, _random_state: RandomState, _buf: &mut Vec) -> PolarsResult<()> { diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs index 954d9fd09ea18..c5f4316b37b6c 100644 --- a/crates/polars-core/src/lib.rs +++ b/crates/polars-core/src/lib.rs @@ -47,6 +47,7 @@ pub static PROCESS_ID: Lazy = Lazy::new(|| { // this is re-exported in utils for polars child crates #[cfg(not(target_family = "wasm"))] // only use this on non wasm targets pub static POOL: Lazy = Lazy::new(|| { + let thread_name = std::env::var("POLARS_THREAD_NAME").unwrap_or_else(|_| "polars".to_string()); ThreadPoolBuilder::new() .num_threads( std::env::var("POLARS_MAX_THREADS") @@ -57,6 +58,7 @@ pub static POOL: Lazy = Lazy::new(|| { .get() }), ) + .thread_name(move |i| format!("{}-{}", thread_name, i)) .build() .expect("could not spawn threads") }); diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index a4bfe06ca3074..69db7deb7e3dc 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -8,8 +8,7 @@ pub use arrow::datatypes::{ArrowSchema, Field as ArrowField}; pub use arrow::legacy::kernels::ewm::EWMOptions; pub use arrow::legacy::prelude::*; pub(crate) use arrow::trusted_len::TrustedLen; -#[cfg(feature = "chunked_ids")] -pub(crate) use polars_utils::index::ChunkId; +pub use polars_utils::index::{ChunkId, IdxSize, NullableChunkId, NullableIdxSize}; pub(crate) use polars_utils::total_ord::{TotalEq, TotalOrd}; pub use crate::chunked_array::arithmetic::ArithmeticChunked; diff --git a/crates/polars-core/src/random.rs b/crates/polars-core/src/random.rs index 1851440631970..f6fd3c1f3978d 100644 --- a/crates/polars-core/src/random.rs +++ b/crates/polars-core/src/random.rs @@ -2,7 +2,6 @@ use std::sync::Mutex; use once_cell::sync::Lazy; use rand::prelude::*; -use rand::rngs::SmallRng; static POLARS_GLOBAL_RNG_STATE: Lazy> = Lazy::new(|| Mutex::new(SmallRng::from_entropy())); diff --git a/crates/polars-core/src/series/any_value.rs b/crates/polars-core/src/series/any_value.rs index d34bad6512df5..280a66cae5644 100644 --- a/crates/polars-core/src/series/any_value.rs +++ b/crates/polars-core/src/series/any_value.rs @@ -1,14 +1,353 @@ use std::fmt::Write; +#[cfg(feature = "object")] +use crate::chunked_array::object::registry::ObjectRegistry; use crate::prelude::*; -use crate::utils::get_supertype; +use crate::utils::try_get_supertype; -fn any_values_to_primitive(avs: &[AnyValue]) -> ChunkedArray { - avs.iter() +impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom]> for Series { + fn new(name: &str, v: T) -> Self { + let av = v.as_ref(); + Series::from_any_values(name, av, true).unwrap() + } +} + +impl Series { + /// Construct a new [`Series`] from a slice of AnyValues. + /// + /// The data type of the resulting Series is determined by the `values` + /// and the `strict` parameter: + /// - If `strict` is `true`, the data type is equal to the data type of the + /// first non-null value. If any other non-null values do not match this + /// data type, an error is raised. + /// - If `strict` is `false`, the data type is the supertype of the `values`. + /// An error is returned if no supertype can be determined. + /// **WARNING**: A full pass over the values is required to determine the supertype. + /// - If no values were passed, the resulting data type is `Null`. + pub fn from_any_values(name: &str, values: &[AnyValue], strict: bool) -> PolarsResult { + fn get_first_non_null_dtype(values: &[AnyValue]) -> DataType { + let mut all_flat_null = true; + let first_non_null = values.iter().find(|av| { + if !av.is_null() { + all_flat_null = false + }; + !av.is_nested_null() + }); + match first_non_null { + Some(av) => av.dtype(), + None => { + if all_flat_null { + DataType::Null + } else { + // Second pass to check for the nested null value that + // toggled `all_flat_null` to false, e.g. a List(Null) + let first_nested_null = values.iter().find(|av| !av.is_null()).unwrap(); + first_nested_null.dtype() + } + }, + } + } + fn get_any_values_supertype(values: &[AnyValue]) -> PolarsResult { + let mut supertype = DataType::Null; + let mut dtypes = PlHashSet::::new(); + for av in values { + if dtypes.insert(av.dtype()) { + supertype = try_get_supertype(&supertype, &av.dtype()).map_err(|_| { + polars_err!( + SchemaMismatch: + "failed to infer supertype of values; partial supertype is {:?}, found value of type {:?}: {}", + supertype, av.dtype(), av + ) + } + )?; + } + } + Ok(supertype) + } + + let dtype = if strict { + get_first_non_null_dtype(values) + } else { + get_any_values_supertype(values)? + }; + Self::from_any_values_and_dtype(name, values, &dtype, strict) + } + + /// Construct a new [`Series`]` with the given `dtype` from a slice of AnyValues. + /// + /// If `strict` is `true`, an error is returned if the values do not match the given + /// data type. If `strict` is `false`, values that do not match the given data type + /// are cast. If casting is not possible, the values are set to null instead.` + pub fn from_any_values_and_dtype( + name: &str, + values: &[AnyValue], + dtype: &DataType, + strict: bool, + ) -> PolarsResult { + let mut s = match dtype { + #[cfg(feature = "dtype-i8")] + DataType::Int8 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => any_values_to_integer::(values, strict)?.into_series(), + DataType::Int32 => any_values_to_integer::(values, strict)?.into_series(), + DataType::Int64 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => any_values_to_integer::(values, strict)?.into_series(), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => any_values_to_integer::(values, strict)?.into_series(), + DataType::UInt32 => any_values_to_integer::(values, strict)?.into_series(), + DataType::UInt64 => any_values_to_integer::(values, strict)?.into_series(), + DataType::Float32 => any_values_to_f32(values, strict)?.into_series(), + DataType::Float64 => any_values_to_f64(values, strict)?.into_series(), + DataType::Boolean => any_values_to_bool(values, strict)?.into_series(), + DataType::String => any_values_to_string(values, strict)?.into_series(), + DataType::Binary => any_values_to_binary(values, strict)?.into_series(), + #[cfg(feature = "dtype-date")] + DataType::Date => any_values_to_primitive_nonstrict::(values) + .into_date() + .into_series(), + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(tu, tz) => any_values_to_primitive_nonstrict::(values) + .into_datetime(*tu, (*tz).clone()) + .into_series(), + #[cfg(feature = "dtype-time")] + DataType::Time => any_values_to_primitive_nonstrict::(values) + .into_time() + .into_series(), + #[cfg(feature = "dtype-duration")] + DataType::Duration(tu) => any_values_to_primitive_nonstrict::(values) + .into_duration(*tu) + .into_series(), + #[cfg(feature = "dtype-categorical")] + dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { + any_values_to_categorical(values, dt, strict)? + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(precision, scale) => { + any_values_to_decimal(values, *precision, *scale)?.into_series() + }, + DataType::List(inner) => any_values_to_list(values, inner, strict)?.into_series(), + #[cfg(feature = "dtype-array")] + DataType::Array(inner, size) => any_values_to_array(values, inner, strict, *size)? + .into_series() + .cast(&DataType::Array(inner.clone(), *size))?, + #[cfg(feature = "dtype-struct")] + DataType::Struct(fields) => any_values_to_struct(values, fields, strict)?, + #[cfg(feature = "object")] + DataType::Object(_, registry) => any_values_to_object(values, registry)?, + DataType::Null => Series::new_null(name, values.len()), + dt => { + polars_bail!( + InvalidOperation: + "constructing a Series with data type {dt:?} from AnyValues is not supported" + ) + }, + }; + s.rename(name); + Ok(s) + } +} + +fn any_values_to_primitive_nonstrict(values: &[AnyValue]) -> ChunkedArray { + values + .iter() .map(|av| av.extract::()) .collect_trusted() } +fn any_values_to_integer( + values: &[AnyValue], + strict: bool, +) -> PolarsResult> { + fn any_values_to_integer_strict( + values: &[AnyValue], + ) -> PolarsResult> { + let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + for av in values { + match av { + av if av.is_integer() => { + let opt_val = av.extract::(); + let val = match opt_val { + Some(v) => v, + None => return Err(invalid_value_error(&T::get_dtype(), av)), + }; + builder.append_value(val) + }, + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&T::get_dtype(), av)), + } + } + Ok(builder.finish()) + } + if strict { + any_values_to_integer_strict::(values) + } else { + Ok(any_values_to_primitive_nonstrict::(values)) + } +} + +fn any_values_to_f32(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_f32_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + for av in values { + match av { + AnyValue::Float32(i) => builder.append_value(*i), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Float32, av)), + } + } + Ok(builder.finish()) + } + if strict { + any_values_to_f32_strict(values) + } else { + Ok(any_values_to_primitive_nonstrict::(values)) + } +} +fn any_values_to_f64(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_f64_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new("", values.len()); + for av in values { + match av { + AnyValue::Float64(i) => builder.append_value(*i), + AnyValue::Float32(i) => builder.append_value(*i as f64), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Float64, av)), + } + } + Ok(builder.finish()) + } + if strict { + any_values_to_f64_strict(values) + } else { + Ok(any_values_to_primitive_nonstrict::(values)) + } +} + +fn any_values_to_bool(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_bool_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = BooleanChunkedBuilder::new("", values.len()); + for av in values { + match av { + AnyValue::Boolean(b) => builder.append_value(*b), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Boolean, av)), + } + } + Ok(builder.finish()) + } + fn any_values_to_bool_nonstrict(values: &[AnyValue]) -> BooleanChunked { + let mapper = |av: &AnyValue| match av { + AnyValue::Boolean(b) => Some(*b), + AnyValue::Null => None, + av => match av.cast(&DataType::Boolean) { + AnyValue::Boolean(b) => Some(b), + _ => None, + }, + }; + values.iter().map(mapper).collect_trusted() + } + if strict { + any_values_to_bool_strict(values) + } else { + Ok(any_values_to_bool_nonstrict(values)) + } +} + +fn any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_string_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = StringChunkedBuilder::new("", values.len()); + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::String, av)), + } + } + Ok(builder.finish()) + } + fn any_values_to_string_nonstrict(values: &[AnyValue]) -> StringChunked { + let mut builder = StringChunkedBuilder::new("", values.len()); + let mut owned = String::new(); // Amortize allocations + for av in values { + match av { + AnyValue::String(s) => builder.append_value(s), + AnyValue::StringOwned(s) => builder.append_value(s), + AnyValue::Null => builder.append_null(), + AnyValue::Binary(_) | AnyValue::BinaryOwned(_) => builder.append_null(), + av => { + owned.clear(); + write!(owned, "{av}").unwrap(); + builder.append_value(&owned); + }, + } + } + builder.finish() + } + if strict { + any_values_to_string_strict(values) + } else { + Ok(any_values_to_string_nonstrict(values)) + } +} + +fn any_values_to_binary(values: &[AnyValue], strict: bool) -> PolarsResult { + fn any_values_to_binary_strict(values: &[AnyValue]) -> PolarsResult { + let mut builder = BinaryChunkedBuilder::new("", values.len()); + for av in values { + match av { + AnyValue::Binary(s) => builder.append_value(*s), + AnyValue::BinaryOwned(s) => builder.append_value(&**s), + AnyValue::Null => builder.append_null(), + av => return Err(invalid_value_error(&DataType::Binary, av)), + } + } + Ok(builder.finish()) + } + fn any_values_to_binary_nonstrict(values: &[AnyValue]) -> BinaryChunked { + values + .iter() + .map(|av| match av { + AnyValue::Binary(b) => Some(*b), + AnyValue::BinaryOwned(b) => Some(&**b), + AnyValue::String(s) => Some(s.as_bytes()), + AnyValue::StringOwned(s) => Some(s.as_bytes()), + _ => None, + }) + .collect_trusted() + } + if strict { + any_values_to_binary_strict(values) + } else { + Ok(any_values_to_binary_nonstrict(values)) + } +} + +#[cfg(feature = "dtype-categorical")] +fn any_values_to_categorical( + values: &[AnyValue], + dtype: &DataType, + strict: bool, +) -> PolarsResult { + let ca = if let Some(single_av) = values.first() { + match single_av { + AnyValue::String(_) | AnyValue::StringOwned(_) | AnyValue::Null => { + any_values_to_string(values, strict)? + }, + _ => polars_bail!( + ComputeError: + "categorical dtype with any-values of dtype {} not supported", + single_av.dtype() + ), + } + } else { + StringChunked::full("", "", 0) + }; + + ca.cast(dtype) +} + #[cfg(feature = "dtype-decimal")] fn any_values_to_decimal( avs: &[AnyValue], @@ -76,50 +415,40 @@ fn any_values_to_decimal( builder.finish().into_decimal(precision, scale) } -#[cfg(feature = "dtype-array")] -fn any_values_to_array( +fn any_values_to_list( avs: &[AnyValue], inner_type: &DataType, strict: bool, - width: usize, -) -> PolarsResult { - fn to_arr(s: &Series) -> Option { - if s.chunks().len() > 1 { - let s = s.rechunk(); - Some(s.chunks()[0].clone()) - } else { - Some(s.chunks()[0].clone()) - } - } +) -> PolarsResult { + let target_dtype = DataType::List(Box::new(inner_type.clone())); // this is handled downstream. The builder will choose the first non null type let mut valid = true; #[allow(unused_mut)] - let mut out: ArrayChunked = if inner_type == &DataType::Null { + let mut out: ListChunked = if inner_type == &DataType::Null { avs.iter() .map(|av| match av { - AnyValue::List(b) | AnyValue::Array(b, _) => to_arr(b), + AnyValue::List(b) => Some(b.clone()), AnyValue::Null => None, _ => { valid = false; None }, }) - .collect_ca_with_dtype("", DataType::Array(Box::new(inner_type.clone()), width)) + .collect_trusted() } // make sure that wrongly inferred AnyValues don't deviate from the datatype else { avs.iter() .map(|av| match av { - AnyValue::List(b) | AnyValue::Array(b, _) => { + AnyValue::List(b) => { if b.dtype() == inner_type { - to_arr(b) + Some(b.clone()) } else { - let s = match b.cast(inner_type) { - Ok(out) => out, - Err(_) => Series::full_null(b.name(), b.len(), inner_type), - }; - to_arr(&s) + match b.cast(inner_type) { + Ok(out) => Some(out), + Err(_) => Some(Series::full_null(b.name(), b.len(), inner_type)), + } } }, AnyValue::Null => None, @@ -128,60 +457,70 @@ fn any_values_to_array( None }, }) - .collect_ca_with_dtype("", DataType::Array(Box::new(inner_type.clone()), width)) + .collect_trusted() }; - if let DataType::Array(_, s) = out.dtype() { - polars_ensure!(*s == width, ComputeError: "got mixed size array widths where width {} was expected", width) + + if strict && !valid { + polars_bail!(SchemaMismatch: "unexpected value while building Series of type {:?}", target_dtype); } + // Ensure the logical type is correct for nested types #[cfg(feature = "dtype-struct")] - if !matches!(inner_type, DataType::Null) - && matches!(out.inner_dtype(), DataType::Struct(_) | DataType::List(_)) - { - // ensure the logical type is correct + if !matches!(inner_type, DataType::Null) && out.inner_dtype().is_nested() { unsafe { - out.set_dtype(DataType::Array(Box::new(inner_type.clone()), width)); + out.set_dtype(target_dtype.clone()); }; } - if valid || !strict { - Ok(out) - } else { - polars_bail!(ComputeError: "got mixed dtypes while constructing List Series") - } + + Ok(out) } -fn any_values_to_list( +#[cfg(feature = "dtype-array")] +fn any_values_to_array( avs: &[AnyValue], inner_type: &DataType, strict: bool, -) -> PolarsResult { + width: usize, +) -> PolarsResult { + fn to_arr(s: &Series) -> Option { + if s.chunks().len() > 1 { + let s = s.rechunk(); + Some(s.chunks()[0].clone()) + } else { + Some(s.chunks()[0].clone()) + } + } + + let target_dtype = DataType::Array(Box::new(inner_type.clone()), width); + // this is handled downstream. The builder will choose the first non null type let mut valid = true; #[allow(unused_mut)] - let mut out: ListChunked = if inner_type == &DataType::Null { + let mut out: ArrayChunked = if inner_type == &DataType::Null { avs.iter() .map(|av| match av { - AnyValue::List(b) => Some(b.clone()), + AnyValue::List(b) | AnyValue::Array(b, _) => to_arr(b), AnyValue::Null => None, _ => { valid = false; None }, }) - .collect_trusted() + .collect_ca_with_dtype("", target_dtype.clone()) } // make sure that wrongly inferred AnyValues don't deviate from the datatype else { avs.iter() .map(|av| match av { - AnyValue::List(b) => { + AnyValue::List(b) | AnyValue::Array(b, _) => { if b.dtype() == inner_type { - Some(b.clone()) + to_arr(b) } else { - match b.cast(inner_type) { - Ok(out) => Some(out), - Err(_) => Some(Series::full_null(b.name(), b.len(), inner_type)), - } + let s = match b.cast(inner_type) { + Ok(out) => out, + Err(_) => Series::full_null(b.name(), b.len(), inner_type), + }; + to_arr(&s) } }, AnyValue::Null => None, @@ -190,374 +529,142 @@ fn any_values_to_list( None }, }) - .collect_trusted() + .collect_ca_with_dtype("", target_dtype.clone()) }; + + if strict && !valid { + polars_bail!(SchemaMismatch: "unexpected value while building Series of type {:?}", target_dtype); + } + polars_ensure!( + out.width() == width, + SchemaMismatch: "got mixed size array widths where width {} was expected", width + ); + + // Ensure the logical type is correct for nested types #[cfg(feature = "dtype-struct")] - if !matches!(inner_type, DataType::Null) - && matches!(out.inner_dtype(), DataType::Struct(_) | DataType::List(_)) - { - // ensure the logical type is correct + if !matches!(inner_type, DataType::Null) && out.inner_dtype().is_nested() { unsafe { - out.set_dtype(DataType::List(Box::new(inner_type.clone()))); + out.set_dtype(target_dtype.clone()); }; } - if valid || !strict { - Ok(out) - } else { - polars_bail!(ComputeError: "got mixed dtypes while constructing List Series") - } + + Ok(out) } -impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom]> for Series { - fn new(name: &str, v: T) -> Self { - let av = v.as_ref(); - Series::from_any_values(name, av, true).unwrap() +#[cfg(feature = "dtype-struct")] +fn any_values_to_struct( + values: &[AnyValue], + fields: &[Field], + strict: bool, +) -> PolarsResult { + // Fast path for empty structs. + if fields.is_empty() { + return Ok(StructChunked::full_null("", values.len()).into_series()); } -} -impl Series { - /// Construct a new [`Series`]` with the given `dtype` from a slice of AnyValues. - pub fn from_any_values_and_dtype( - name: &str, - av: &[AnyValue], - dtype: &DataType, - strict: bool, - ) -> PolarsResult { - let mut s = match dtype { - #[cfg(feature = "dtype-i8")] - DataType::Int8 => any_values_to_primitive::(av).into_series(), - #[cfg(feature = "dtype-i16")] - DataType::Int16 => any_values_to_primitive::(av).into_series(), - DataType::Int32 => any_values_to_primitive::(av).into_series(), - DataType::Int64 => any_values_to_primitive::(av).into_series(), - #[cfg(feature = "dtype-u8")] - DataType::UInt8 => any_values_to_primitive::(av).into_series(), - #[cfg(feature = "dtype-u16")] - DataType::UInt16 => any_values_to_primitive::(av).into_series(), - DataType::UInt32 => any_values_to_primitive::(av).into_series(), - DataType::UInt64 => any_values_to_primitive::(av).into_series(), - DataType::Float32 => any_values_to_primitive::(av).into_series(), - DataType::Float64 => any_values_to_primitive::(av).into_series(), - DataType::String => any_values_to_string(av, strict)?.into_series(), - DataType::Binary => any_values_to_binary(av, strict)?.into_series(), - DataType::Boolean => any_values_to_bool(av, strict)?.into_series(), - #[cfg(feature = "dtype-date")] - DataType::Date => any_values_to_primitive::(av) - .into_date() - .into_series(), - #[cfg(feature = "dtype-datetime")] - DataType::Datetime(tu, tz) => any_values_to_primitive::(av) - .into_datetime(*tu, (*tz).clone()) - .into_series(), - #[cfg(feature = "dtype-time")] - DataType::Time => any_values_to_primitive::(av) - .into_time() - .into_series(), - #[cfg(feature = "dtype-duration")] - DataType::Duration(tu) => any_values_to_primitive::(av) - .into_duration(*tu) - .into_series(), - #[cfg(feature = "dtype-decimal")] - DataType::Decimal(precision, scale) => { - any_values_to_decimal(av, *precision, *scale)?.into_series() - }, - DataType::List(inner) => any_values_to_list(av, inner, strict)?.into_series(), - #[cfg(feature = "dtype-array")] - DataType::Array(inner, size) => any_values_to_array(av, inner, strict, *size)? - .into_series() - .cast(&DataType::Array(inner.clone(), *size))?, - #[cfg(feature = "dtype-struct")] - DataType::Struct(dtype_fields) => { - // fast path for empty structs - if dtype_fields.is_empty() { - return Ok(StructChunked::full_null(name, av.len()).into_series()); - } - // the physical series fields of the struct - let mut series_fields = Vec::with_capacity(dtype_fields.len()); - for (i, field) in dtype_fields.iter().enumerate() { - let mut field_avs = Vec::with_capacity(av.len()); - - for av in av.iter() { - match av { - AnyValue::StructOwned(payload) => { - // TODO: optimize - let av_fields = &payload.1; - let av_values = &payload.0; - - let mut append_by_search = || { - // search for the name - let mut pushed = false; - for (av_fld, av_val) in av_fields.iter().zip(av_values) { - if av_fld.name == field.name { - field_avs.push(av_val.clone()); - pushed = true; - break; - } - } - if !pushed { - field_avs.push(AnyValue::Null) - } - }; - - // all fields are available in this single value - // we can use the index to get value - if dtype_fields.len() == av_fields.len() { - let mut search = false; - for (l, r) in dtype_fields.iter().zip(av_fields.iter()) { - if l.name() != r.name() { - search = true; - } - } - if search { - append_by_search() - } else { - let av_val = - av_values.get(i).cloned().unwrap_or(AnyValue::Null); - field_avs.push(av_val) - } - } - // not all fields are available, we search the proper field - else { - // search for the name - append_by_search() - } - }, - _ => field_avs.push(AnyValue::Null), + // The physical series fields of the struct. + let mut series_fields = Vec::with_capacity(fields.len()); + for (i, field) in fields.iter().enumerate() { + let mut field_avs = Vec::with_capacity(values.len()); + + for av in values.iter() { + match av { + AnyValue::StructOwned(payload) => { + // TODO: Optimize. + let av_fields = &payload.1; + let av_values = &payload.0; + + let mut append_by_search = || { + // Search for the name. + let mut pushed = false; + for (av_fld, av_val) in av_fields.iter().zip(av_values) { + if av_fld.name == field.name { + field_avs.push(av_val.clone()); + pushed = true; + break; + } + } + if !pushed { + field_avs.push(AnyValue::Null) } - } - // if the inferred dtype is null, we let auto inference work - let s = if matches!(field.dtype, DataType::Null) { - Series::new(field.name(), &field_avs) - } else { - Series::from_any_values_and_dtype( - field.name(), - &field_avs, - &field.dtype, - strict, - )? }; - series_fields.push(s) - } - return StructChunked::new(name, &series_fields).map(|ca| ca.into_series()); - }, - #[cfg(feature = "object")] - DataType::Object(_, registry) => { - match registry { - None => { - use crate::chunked_array::object::registry; - let converter = registry::get_object_converter(); - let mut builder = registry::get_object_builder(name, av.len()); - for av in av { - match av { - AnyValue::Object(val) => builder.append_value(val.as_any()), - AnyValue::Null => builder.append_null(), - _ => { - // This is needed because in python people can send mixed types. - // This only works if you set a global converter. - let any = converter(av.as_borrowed()); - builder.append_value(&*any) - }, + + // All fields are available in this single value. + // We can use the index to get value. + if fields.len() == av_fields.len() { + let mut search = false; + for (l, r) in fields.iter().zip(av_fields.iter()) { + if l.name() != r.name() { + search = true; } } - return Ok(builder.to_series()); - }, - Some(registry) => { - let mut builder = (*registry.builder_constructor)(name, av.len()); - for av in av { - match av { - AnyValue::Object(val) => builder.append_value(val.as_any()), - AnyValue::Null => builder.append_null(), - _ => { - polars_bail!(ComputeError: "expected object"); - }, - } + if search { + append_by_search() + } else { + let av_val = av_values.get(i).cloned().unwrap_or(AnyValue::Null); + field_avs.push(av_val) } - return Ok(builder.to_series()); - }, - } - }, - DataType::Null => Series::new_null(name, av.len()), - #[cfg(feature = "dtype-categorical")] - dt @ (DataType::Categorical(_, _) | DataType::Enum(_, _)) => { - let ca = if let Some(single_av) = av.first() { - match single_av { - AnyValue::String(_) | AnyValue::StringOwned(_) | AnyValue::Null => { - any_values_to_string(av, strict)? - }, - _ => polars_bail!( - ComputeError: - "categorical dtype with any-values of dtype {} not supported", - single_av.dtype() - ), } - } else { - StringChunked::full("", "", 0) - }; - - ca.cast(dt).unwrap() - }, - dt => panic!("{dt:?} not supported"), - }; - s.rename(name); - Ok(s) - } - - /// Construct a new [`Series`] from a slice of AnyValues. - /// - /// The data type of the resulting Series is determined by the `values` - /// and the `strict` parameter: - /// - If `strict` is `true`, the data type is equal to the data type of the - /// first non-null value. If any other non-null values do not match this - /// data type, an error is raised. - /// - If `strict` is `false`, the data type is the supertype of the - /// `values`. **WARNING**: A full pass over the values is required to - /// determine the supertype. Values encountered that do not match the - /// supertype are set to null. - /// - If no values were passed, the resulting data type is `Null`. - pub fn from_any_values(name: &str, values: &[AnyValue], strict: bool) -> PolarsResult { - fn get_first_non_null_dtype(values: &[AnyValue]) -> DataType { - let mut all_flat_null = true; - let first_non_null = values.iter().find(|av| { - if !av.is_null() { - all_flat_null = false - }; - !av.is_nested_null() - }); - match first_non_null { - Some(av) => av.dtype(), - None => { - if all_flat_null { - DataType::Null - } else { - // Second pass to check for the nested null value that - // toggled `all_flat_null` to false, e.g. a List(Null) - let first_nested_null = values.iter().find(|av| !av.is_null()).unwrap(); - first_nested_null.dtype() + // Not all fields are available, we search the proper field. + else { + // Search for the name. + append_by_search() } }, + _ => field_avs.push(AnyValue::Null), } } - fn get_any_values_supertype(values: &[AnyValue]) -> DataType { - let mut supertype = DataType::Null; - let mut dtypes = PlHashSet::::new(); - for av in values { - if dtypes.insert(av.dtype()) { - // Values with incompatible data types will be set to null later - if let Some(st) = get_supertype(&supertype, &av.dtype()) { - supertype = st; - } - } - } - supertype - } - - let dtype = if strict { - get_first_non_null_dtype(values) + // If the inferred dtype is null, we let auto inference work. + let s = if matches!(field.dtype, DataType::Null) { + Series::new(field.name(), &field_avs) } else { - get_any_values_supertype(values) + Series::from_any_values_and_dtype(field.name(), &field_avs, &field.dtype, strict)? }; - Self::from_any_values_and_dtype(name, values, &dtype, strict) + series_fields.push(s) } + StructChunked::new("", &series_fields).map(|ca| ca.into_series()) } -fn any_values_to_bool(values: &[AnyValue], strict: bool) -> PolarsResult { - if strict { - any_values_to_bool_strict(values) - } else { - Ok(any_values_to_bool_nonstrict(values)) - } -} -fn any_values_to_bool_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = BooleanChunkedBuilder::new("", values.len()); - for av in values { - match av { - AnyValue::Boolean(b) => builder.append_value(*b), - AnyValue::Null => builder.append_null(), - av => return Err(invalid_value_error(&DataType::Boolean, av)), - } - } - Ok(builder.finish()) -} -fn any_values_to_bool_nonstrict(values: &[AnyValue]) -> BooleanChunked { - let mapper = |av: &AnyValue| match av { - AnyValue::Boolean(b) => Some(*b), - AnyValue::Null => None, - av => match av.cast(&DataType::Boolean) { - AnyValue::Boolean(b) => Some(b), - _ => None, +#[cfg(feature = "object")] +fn any_values_to_object( + values: &[AnyValue], + registry: &Option>, +) -> PolarsResult { + let mut builder = match registry { + None => { + use crate::chunked_array::object::registry; + let converter = registry::get_object_converter(); + let mut builder = registry::get_object_builder("", values.len()); + for av in values { + match av { + AnyValue::Object(val) => builder.append_value(val.as_any()), + AnyValue::Null => builder.append_null(), + _ => { + // This is needed because in Python users can send mixed types. + // This only works if you set a global converter. + let any = converter(av.as_borrowed()); + builder.append_value(&*any) + }, + } + } + builder + }, + Some(registry) => { + let mut builder = (*registry.builder_constructor)("", values.len()); + for av in values { + match av { + AnyValue::Object(val) => builder.append_value(val.as_any()), + AnyValue::Null => builder.append_null(), + _ => { + polars_bail!(ComputeError: "expected object"); + }, + } + } + builder }, }; - values.iter().map(mapper).collect_trusted() -} - -fn any_values_to_string(values: &[AnyValue], strict: bool) -> PolarsResult { - if strict { - any_values_to_string_strict(values) - } else { - Ok(any_values_to_string_nonstrict(values)) - } -} -fn any_values_to_string_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = StringChunkedBuilder::new("", values.len()); - for av in values { - match av { - AnyValue::String(s) => builder.append_value(s), - AnyValue::StringOwned(s) => builder.append_value(s), - AnyValue::Null => builder.append_null(), - av => return Err(invalid_value_error(&DataType::String, av)), - } - } - Ok(builder.finish()) -} -fn any_values_to_string_nonstrict(values: &[AnyValue]) -> StringChunked { - let mut builder = StringChunkedBuilder::new("", values.len()); - let mut owned = String::new(); // Amortize allocations - for av in values { - match av { - AnyValue::String(s) => builder.append_value(s), - AnyValue::StringOwned(s) => builder.append_value(s), - AnyValue::Null => builder.append_null(), - AnyValue::Binary(_) | AnyValue::BinaryOwned(_) => builder.append_null(), - av => { - owned.clear(); - write!(owned, "{av}").unwrap(); - builder.append_value(&owned); - }, - } - } - builder.finish() -} -fn any_values_to_binary(values: &[AnyValue], strict: bool) -> PolarsResult { - if strict { - any_values_to_binary_strict(values) - } else { - Ok(any_values_to_binary_nonstrict(values)) - } -} -fn any_values_to_binary_strict(values: &[AnyValue]) -> PolarsResult { - let mut builder = BinaryChunkedBuilder::new("", values.len()); - for av in values { - match av { - AnyValue::Binary(s) => builder.append_value(*s), - AnyValue::BinaryOwned(s) => builder.append_value(&**s), - AnyValue::Null => builder.append_null(), - av => return Err(invalid_value_error(&DataType::Binary, av)), - } - } - Ok(builder.finish()) -} -fn any_values_to_binary_nonstrict(values: &[AnyValue]) -> BinaryChunked { - values - .iter() - .map(|av| match av { - AnyValue::Binary(b) => Some(*b), - AnyValue::BinaryOwned(b) => Some(&**b), - AnyValue::String(s) => Some(s.as_bytes()), - AnyValue::StringOwned(s) => Some(s.as_bytes()), - _ => None, - }) - .collect_trusted() + Ok(builder.to_series()) } fn invalid_value_error(dtype: &DataType, value: &AnyValue) -> PolarsError { diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 67bd8066ac791..22101912b285c 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -708,21 +708,21 @@ where #[must_use] pub fn lhs_sub(&self, lhs: N) -> Self { let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); - self.apply_values(|v| lhs - v) + ArithmeticChunked::wrapping_sub_scalar_lhs(lhs, self) } /// Apply lhs / self #[must_use] pub fn lhs_div(&self, lhs: N) -> Self { let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); - self.apply_values(|v| lhs / v) + ArithmeticChunked::legacy_div_scalar_lhs(lhs, self) } /// Apply lhs % self #[must_use] pub fn lhs_rem(&self, lhs: N) -> Self { let lhs: T::Native = NumCast::from(lhs).expect("could not cast"); - self.apply_values(|v| lhs % v) + ArithmeticChunked::wrapping_mod_scalar_lhs(lhs, self) } } diff --git a/crates/polars-core/src/series/comparison.rs b/crates/polars-core/src/series/comparison.rs index efa8726b2b998..15c891aef935c 100644 --- a/crates/polars-core/src/series/comparison.rs +++ b/crates/polars-core/src/series/comparison.rs @@ -3,8 +3,6 @@ #[cfg(feature = "dtype-struct")] use std::ops::Deref; -use super::Series; -use crate::apply_method_physical_numeric; use crate::prelude::*; use crate::series::arithmetic::coerce_lhs_rhs; use crate::series::nulls::replace_non_null; diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 86e4b2ae64c04..e8a729ab29d6a 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use arrow::compute::cast::cast_unchecked as cast; use arrow::datatypes::Metadata; #[cfg(any(feature = "dtype-struct", feature = "dtype-categorical"))] diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index f853a113c4d59..164eeceb8ba77 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -1,7 +1,7 @@ use std::any::Any; use std::borrow::Cow; -use super::{private, IntoSeries, SeriesTrait}; +use super::private; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::{AsSinglePtr, Settings}; diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index 830a53f937370..86705e8f9af3b 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -1,18 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index a8af560c9d611..d0a5523c7d8cf 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -1,16 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index e38f82204ca60..1aa17d298d3f6 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -1,19 +1,8 @@ -use std::borrow::Cow; -use std::ops::{BitAnd, BitOr, BitXor}; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::{AsSinglePtr, ChunkIdIter}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 1d7d50b636c2f..f9a23c2614177 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -1,14 +1,6 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{IntoTotalOrdInner, TotalOrdInner}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; unsafe impl IntoSeries for CategoricalChunked { fn into_series(self) -> Series { @@ -26,7 +18,7 @@ impl SeriesWrap { self.0.get_ordering(), ) }; - if keep_fast_unique && self.0.can_fast_unique() { + if keep_fast_unique && self.0._can_fast_unique() { out.set_fast_unique(true) } out diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index ca2ef989146db..5f5e993dcbbc4 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -7,15 +7,7 @@ //! opting for a little more run time cost. We cast to the physical type -> apply the operation and //! (depending on the result) cast back to the original type //! -use std::borrow::Cow; -use std::ops::Deref; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::ops::ToBitRepr; -use crate::chunked_array::AsSinglePtr; +use super::*; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 7504a72692936..c4c8bfe1b47b4 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -1,11 +1,4 @@ -use std::borrow::Cow; -use std::ops::Deref; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; +use super::*; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 2f89924d319de..d50befa1059cc 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -1,4 +1,4 @@ -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; +use super::*; use crate::prelude::*; unsafe impl IntoSeries for DecimalChunked { @@ -18,11 +18,41 @@ impl SeriesWrap { fn agg_helper Series>(&self, f: F) -> Series { let agg_s = f(&self.0); - let ca = agg_s.decimal().unwrap(); - let ca = ca.as_ref().clone(); - let precision = self.0.precision(); - let scale = self.0.scale(); - ca.into_decimal_unchecked(precision, scale).into_series() + match agg_s.dtype() { + DataType::Decimal(_, _) => { + let ca = agg_s.decimal().unwrap(); + let ca = ca.as_ref().clone(); + let precision = self.0.precision(); + let scale = self.0.scale(); + ca.into_decimal_unchecked(precision, scale).into_series() + }, + DataType::List(dtype) if dtype.is_decimal() => { + let dtype = self.0.dtype(); + let ca = agg_s.list().unwrap(); + let arr = ca.downcast_iter().next().unwrap(); + // SAFETY: dtype is passed correctly + let s = unsafe { + Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], dtype) + }; + let new_values = s.array_ref(0).clone(); + let data_type = ListArray::::default_datatype(dtype.to_arrow(true)); + let new_arr = ListArray::::new( + data_type, + arr.offsets().clone(), + new_values, + arr.validity().cloned(), + ); + unsafe { + ListChunked::from_chunks_and_dtype_unchecked( + agg_s.name(), + vec![Box::new(new_arr)], + DataType::List(Box::new(self.dtype().clone())), + ) + .into_series() + } + }, + _ => unreachable!(), + } } } @@ -66,6 +96,22 @@ impl private::PrivateSeries for SeriesWrap { .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) } + fn into_total_eq_inner<'a>(&'a self) -> Box { + (&self.0).into_total_eq_inner() + } + fn into_total_ord_inner<'a>(&'a self) -> Box { + (&self.0).into_total_ord_inner() + } + + fn vec_hash(&self, random_state: RandomState, buf: &mut Vec) -> PolarsResult<()> { + self.0.vec_hash(random_state, buf)?; + Ok(()) + } + + fn vec_hash_combine(&self, build_hasher: RandomState, hashes: &mut [u64]) -> PolarsResult<()> { + self.0.vec_hash_combine(build_hasher, hashes)?; + Ok(()) + } #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { @@ -84,7 +130,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { - self.0.agg_list(groups) + self.agg_helper(|ca| ca.agg_list(groups)) } fn subtract(&self, rhs: &Series) -> PolarsResult { @@ -103,6 +149,10 @@ impl private::PrivateSeries for SeriesWrap { let rhs = rhs.decimal()?; ((&self.0) / rhs).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] + fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { + self.0.group_tuples(multithreaded, sorted) + } } impl SeriesTrait for SeriesWrap { @@ -211,6 +261,17 @@ impl SeriesTrait for SeriesWrap { self.0.get_any_value_unchecked(index) } + fn sort_with(&self, options: SortOptions) -> Series { + self.0 + .sort_with(options) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series() + } + + fn arg_sort(&self, options: SortOptions) -> IdxCa { + self.0.arg_sort(options) + } + fn null_count(&self) -> usize { self.0.null_count() } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 7a050630a5383..30e3f30857e05 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -1,12 +1,7 @@ -use std::borrow::Cow; -use std::ops::{Deref, DerefMut}; +use std::ops::DerefMut; -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index b82356c9f5bc5..3332649da16b2 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -1,22 +1,8 @@ -use std::any::Any; -use std::borrow::Cow; - -use ahash::RandomState; -use arrow::legacy::prelude::QuantileInterpolOptions; - -use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::aggregate::{ChunkAggSeries, QuantileAggSeries, VarAggSeries}; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -#[cfg(feature = "checked_arithmetic")] -use crate::series::arithmetic::checked::NumOpsDispatchChecked; macro_rules! impl_dyn_series { ($ca: ident) => { diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index 183f572ee2b28..da0c2ce27366a 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -1,20 +1,8 @@ -use std::any::Any; -use std::borrow::Cow; - -#[cfg(feature = "group_by_list")] -use ahash::RandomState; - use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::{AsSinglePtr, Settings}; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; -#[cfg(feature = "chunked_ids")] -use crate::series::IsSorted; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 4fb80344a59da..dce50670131cc 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -28,22 +28,17 @@ mod struct_; use std::any::Any; use std::borrow::Cow; -use std::ops::{BitAnd, BitOr, BitXor, Deref}; +use std::ops::{BitAnd, BitOr, BitXor}; use ahash::RandomState; -use arrow::legacy::prelude::QuantileInterpolOptions; -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::aggregate::{ChunkAggSeries, QuantileAggSeries, VarAggSeries}; use crate::chunked_array::ops::compare_inner::{ IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; -use crate::prelude::*; -#[cfg(feature = "checked_arithmetic")] -use crate::series::arithmetic::checked::NumOpsDispatchChecked; // Utility wrapper struct pub(crate) struct SeriesWrap(pub T); diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 75cf6d4eb2c95..0550b360a7a86 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -1,13 +1,7 @@ use std::any::Any; -use std::borrow::Cow; -use std::sync::Arc; -use arrow::array::ArrayRef; use polars_error::constants::LENGTH_LIMIT_MSG; -use polars_utils::IdxSize; -use crate::datatypes::IdxCa; -use crate::error::PolarsResult; use crate::prelude::compare_inner::{IntoTotalEqInner, TotalEqInner}; use crate::prelude::explode::ExplodeByOffsets; use crate::prelude::*; diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index 6434c66d782c2..d60b474c376d9 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -6,8 +6,6 @@ use ahash::RandomState; use crate::chunked_array::object::PolarsObjectSafe; use crate::chunked_array::ops::compare_inner::{IntoTotalEqInner, TotalEqInner}; use crate::chunked_array::Settings; -#[cfg(feature = "algorithm_group_by")] -use crate::frame::group_by::{GroupsProxy, IntoGroupsProxy}; use crate::prelude::*; use crate::series::implementations::SeriesWrap; use crate::series::private::{PrivateSeries, PrivateSeriesNumeric}; diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index b43bd2dcaba77..9a8c1b1f6aa41 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -1,18 +1,8 @@ -use std::borrow::Cow; - -use ahash::RandomState; - -use super::{private, IntoSeries, SeriesTrait, *}; +use super::*; use crate::chunked_array::comparison::*; -use crate::chunked_array::ops::compare_inner::{ - IntoTotalEqInner, IntoTotalOrdInner, TotalEqInner, TotalOrdInner, -}; -use crate::chunked_array::ops::explode::ExplodeByOffsets; -use crate::chunked_array::AsSinglePtr; #[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; -use crate::series::implementations::SeriesWrap; impl private::PrivateSeries for SeriesWrap { fn compute_len(&mut self) { diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 69e075a08ec09..fcf85754aac75 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -1,5 +1,3 @@ -use std::any::Any; - use super::*; use crate::hashing::series_to_hashes; use crate::prelude::*; @@ -65,7 +63,7 @@ impl private::PrivateSeries for SeriesWrap { #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { - let df = DataFrame::new_no_checks(vec![]); + let df = DataFrame::empty(); let gb = df .group_by_with_series(self.0.fields().to_vec(), multithreaded, sorted) .unwrap(); diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 2e184655440e2..32a697f550903 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -16,7 +16,6 @@ pub mod unstable; use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::ops::Deref; -use std::sync::Arc; use ahash::RandomState; use arrow::compute::aggregate::estimated_bytes_size; @@ -221,7 +220,8 @@ impl Series { } pub fn into_frame(self) -> DataFrame { - DataFrame::new_no_checks(vec![self]) + // SAFETY: A single-column dataframe cannot have length mismatches or duplicate names + unsafe { DataFrame::new_no_checks(vec![self]) } } /// Rename series. @@ -448,7 +448,10 @@ impl Series { Date => Cow::Owned(self.cast(&Int32).unwrap()), Datetime(_, _) | Duration(_) | Time => Cow::Owned(self.cast(&Int64).unwrap()), #[cfg(feature = "dtype-categorical")] - Categorical(_, _) | Enum(_, _) => Cow::Owned(self.cast(&UInt32).unwrap()), + Categorical(_, _) | Enum(_, _) => { + let ca = self.categorical().unwrap(); + Cow::Owned(ca.physical().clone().into_series()) + }, List(inner) => Cow::Owned(self.cast(&List(Box::new(inner.to_physical()))).unwrap()), #[cfg(feature = "dtype-struct")] Struct(_) => { @@ -596,7 +599,7 @@ impl Series { /// /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. - pub fn product(&self) -> Series { + pub fn product(&self) -> PolarsResult { #[cfg(feature = "product")] { use DataType::*; @@ -606,11 +609,13 @@ impl Series { let s = self.cast(&Int64).unwrap(); s.product() }, - Int64 => self.i64().unwrap().prod_as_series(), - UInt64 => self.u64().unwrap().prod_as_series(), - Float32 => self.f32().unwrap().prod_as_series(), - Float64 => self.f64().unwrap().prod_as_series(), - dt => panic!("product not supported for dtype: {dt:?}"), + Int64 => Ok(self.i64().unwrap().prod_as_series()), + UInt64 => Ok(self.u64().unwrap().prod_as_series()), + Float32 => Ok(self.f32().unwrap().prod_as_series()), + Float64 => Ok(self.f64().unwrap().prod_as_series()), + dt => { + polars_bail!(InvalidOperation: "`product` operation not supported for dtype `{dt}`") + }, } } #[cfg(not(feature = "product"))] @@ -908,8 +913,6 @@ where #[cfg(test)] mod test { - use std::convert::TryFrom; - use crate::prelude::*; use crate::series::*; diff --git a/crates/polars-core/src/series/ops/to_list.rs b/crates/polars-core/src/series/ops/to_list.rs index e35cc7a3c93cb..3b1c7f757be7b 100644 --- a/crates/polars-core/src/series/ops/to_list.rs +++ b/crates/polars-core/src/series/ops/to_list.rs @@ -126,7 +126,6 @@ impl Series { #[cfg(test)] mod test { use super::*; - use crate::chunked_array::builder::get_list_builder; #[test] fn test_to_list() -> PolarsResult<()> { diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 7dc88b2a1f795..eb976d0d9b78e 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -1,9 +1,6 @@ use std::any::Any; use std::borrow::Cow; -#[cfg(feature = "temporal")] -use std::sync::Arc; -use arrow::legacy::prelude::QuantileInterpolOptions; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -46,8 +43,6 @@ pub(crate) mod private { use super::*; use crate::chunked_array::ops::compare_inner::{TotalEqInner, TotalOrdInner}; use crate::chunked_array::Settings; - #[cfg(feature = "algorithm_group_by")] - use crate::frame::group_by::GroupsProxy; pub trait PrivateSeriesNumeric { fn bit_repr_is_large(&self) -> bool { @@ -212,6 +207,7 @@ pub trait SeriesTrait: fn chunks(&self) -> &Vec; /// Underlying chunks. + /// /// # Safety /// The caller must ensure the length and the data types of `ArrayRef` does not change. unsafe fn chunks_mut(&mut self) -> &mut Vec; @@ -459,6 +455,7 @@ pub trait SeriesTrait: #[cfg(feature = "object")] /// Get the value at this index as a downcastable Any trait ref. + /// /// # Safety /// This function doesn't do any bound checks. unsafe fn get_object_chunked_unchecked( diff --git a/crates/polars-core/src/series/unstable.rs b/crates/polars-core/src/series/unstable.rs index ef85d673a6dca..d4bed3cb10350 100644 --- a/crates/polars-core/src/series/unstable.rs +++ b/crates/polars-core/src/series/unstable.rs @@ -45,6 +45,7 @@ impl<'a> UnstableSeries<'a> { } /// Creates a new `[UnsafeSeries]` + /// /// # Safety /// Inner chunks must be from `Series` otherwise the dtype may be incorrect and lead to UB. #[inline] diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs index 82003da6f0c23..99c28a617b2bc 100644 --- a/crates/polars-core/src/testing.rs +++ b/crates/polars-core/src/testing.rs @@ -164,6 +164,18 @@ impl PartialEq for DataFrame { } } +/// Asserts that two expressions of type [`DataFrame`] are equal according to [`DataFrame::equals`] +/// at runtime. If the expression are not equal, the program will panic with a message that displays +/// both dataframes. +#[macro_export] +macro_rules! assert_df_eq { + ($a:expr, $b:expr $(,)?) => { + let a: &$crate::frame::DataFrame = &$a; + let b: &$crate::frame::DataFrame = &$b; + assert!(a.equals(b), "expected {:?}\nto equal {:?}", a, b); + }; +} + #[cfg(test)] mod test { use crate::prelude::*; @@ -194,6 +206,19 @@ mod test { assert!(df1.equals(&df1)) } + #[test] + fn assert_df_eq_passes() { + let df = df!("a" => [1], "b" => [2]).unwrap(); + assert_df_eq!(df, df); + drop(df); // Ensure `assert_df_eq!` does not consume its arguments. + } + + #[test] + #[should_panic(expected = "to equal")] + fn assert_df_eq_panics() { + assert_df_eq!(df!("a" => [1]).unwrap(), df!("a" => [2]).unwrap(),); + } + #[test] fn test_df_partialeq() { let df1 = df!("a" => &[1, 2, 3], diff --git a/crates/polars-core/src/utils/flatten.rs b/crates/polars-core/src/utils/flatten.rs index 024e20de93545..54eadf5d74a55 100644 --- a/crates/polars-core/src/utils/flatten.rs +++ b/crates/polars-core/src/utils/flatten.rs @@ -1,23 +1,24 @@ +use arrow::bitmap::MutableBitmap; use polars_utils::sync::SyncPtr; use super::*; pub fn flatten_df_iter(df: &DataFrame) -> impl Iterator + '_ { df.iter_chunks_physical().flat_map(|chunk| { - let df = DataFrame::new_no_checks( - df.iter() - .zip(chunk.into_arrays()) - .map(|(s, arr)| { - // SAFETY: - // datatypes are correct - let mut out = unsafe { - Series::from_chunks_and_dtype_unchecked(s.name(), vec![arr], s.dtype()) - }; - out.set_sorted_flag(s.is_sorted_flag()); - out - }) - .collect(), - ); + let columns = df + .iter() + .zip(chunk.into_arrays()) + .map(|(s, arr)| { + // SAFETY: + // datatypes are correct + let mut out = unsafe { + Series::from_chunks_and_dtype_unchecked(s.name(), vec![arr], s.dtype()) + }; + out.set_sorted_flag(s.is_sorted_flag()); + out + }) + .collect(); + let df = unsafe { DataFrame::new_no_checks(columns) }; if df.height() == 0 { None } else { @@ -89,3 +90,31 @@ fn flatten_par_impl( } out } + +pub fn flatten_nullable + Send + Sync>( + bufs: &[S], +) -> PrimitiveArray { + let a = || flatten_par(bufs); + let b = || { + let cap = bufs.iter().map(|s| s.as_ref().len()).sum::(); + let mut validity = MutableBitmap::with_capacity(cap); + validity.extend_constant(cap, true); + + let mut count = 0usize; + for s in bufs { + let s = s.as_ref(); + + for id in s { + if id.is_null_idx() { + unsafe { validity.set_bit_unchecked(count, false) }; + } + + count += 1; + } + } + validity.freeze() + }; + + let (a, b) = POOL.join(a, b); + PrimitiveArray::from_vec(bytemuck::cast_vec::<_, IdxSize>(a)).with_validity(Some(b)) +} diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index bb7da62bf0d9d..1ccf29942bcc0 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -133,7 +133,11 @@ pub fn split_series(s: &Series, n: usize) -> PolarsResult> { split_array!(s, n, i64) } -pub fn split_df_as_ref(df: &DataFrame, n: usize) -> PolarsResult> { +pub fn split_df_as_ref( + df: &DataFrame, + n: usize, + extend_sub_chunks: bool, +) -> PolarsResult> { let total_len = df.height(); let chunk_size = std::cmp::max(total_len / n, 1); @@ -155,7 +159,7 @@ pub fn split_df_as_ref(df: &DataFrame, n: usize) -> PolarsResult> chunk_size }; let df = df.slice((i * chunk_size) as i64, len); - if df.n_chunks() > 1 { + if extend_sub_chunks && df.n_chunks() > 1 { // we add every chunk as separate dataframe. This make sure that every partition // deals with it. out.extend(flatten_df_iter(&df)) @@ -175,7 +179,7 @@ pub fn split_df(df: &mut DataFrame, n: usize) -> PolarsResult> { } // make sure that chunks are aligned. df.align_chunks(); - split_df_as_ref(df, n) + split_df_as_ref(df, n, true) } pub fn slice_slice(vals: &[T], offset: i64, len: usize) -> &[T] { @@ -498,18 +502,6 @@ macro_rules! apply_method_all_arrow_series { } } -#[macro_export] -macro_rules! apply_amortized_generic_list_or_array { - ($self:expr, $method:ident, $($args:expr),*) => { - match $self.dtype() { - #[cfg(feature = "dtype-array")] - DataType::Array(_, _) => $self.array().unwrap().apply_amortized_generic($($args),*), - DataType::List(_) => $self.list().unwrap().apply_amortized_generic($($args),*), - dt => panic!("not implemented for dtype {:?}", dt), - } - } -} - #[macro_export] macro_rules! apply_method_physical_integer { ($self:expr, $method:ident, $($args:expr),*) => { @@ -546,9 +538,9 @@ macro_rules! apply_method_physical_numeric { #[macro_export] macro_rules! df { ($($col_name:expr => $slice:expr), + $(,)?) => { - { - $crate::prelude::DataFrame::new(vec![$($crate::prelude::Series::new($col_name, $slice),)+]) - } + $crate::prelude::DataFrame::new(vec![ + $(<$crate::prelude::Series as $crate::prelude::NamedFrom::<_, _>>::new($col_name, $slice),)+ + ]) } } diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs index 6a107d595a48a..9db543263f836 100644 --- a/crates/polars-core/src/utils/series.rs +++ b/crates/polars-core/src/utils/series.rs @@ -2,7 +2,7 @@ use crate::prelude::*; use crate::series::unstable::UnstableSeries; use crate::series::IsSorted; -/// Transform to physical type and coerce floating point and similar sized integer to a bit representation +/// Transform to physical type and coerce similar sized integer to a bit representation /// to reduce compiler bloat pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec { s.iter() @@ -11,8 +11,6 @@ pub fn _to_physical_and_bit_repr(s: &[Series]) -> Vec { match physical.dtype() { DataType::Int64 => physical.bit_repr_large().into_series(), DataType::Int32 => physical.bit_repr_small().into_series(), - DataType::Float32 => physical.bit_repr_small().into_series(), - DataType::Float64 => physical.bit_repr_large().into_series(), _ => physical.into_owned(), } }) @@ -50,6 +48,10 @@ pub fn handle_casting_failures(input: &Series, output: &Series) -> PolarsResult< - setting `strict=False` to set values that cannot be converted to `null`\n\ - using `str.strptime`, `str.to_date`, or `str.to_datetime` and providing a format string" }, + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Enum(_,_)) => { + "\n\nEnsure that all values in the input column are present in the categories of the enum datatype." + } _ => "", }; diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index 67f7429fce2ed..dbf7e92eac67f 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -274,7 +274,7 @@ macro_rules! polars_bail { macro_rules! polars_ensure { ($cond:expr, $($tt:tt)+) => { if !$cond { - polars_bail!($($tt)+); + $crate::polars_bail!($($tt)+); } }; } diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index f0573aee8ed17..43c3d677b8db5 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -80,7 +80,7 @@ dtype-i8 = ["polars-core/dtype-i8"] dtype-i16 = ["polars-core/dtype-i16"] dtype-categorical = ["polars-core/dtype-categorical"] dtype-date = ["polars-core/dtype-date", "polars-time/dtype-date"] -object = [] +object = ["polars-core/object"] dtype-datetime = [ "polars-core/dtype-datetime", "polars-core/temporal", @@ -90,6 +90,7 @@ dtype-datetime = [ timezones = [ "chrono-tz", "dtype-datetime", + "arrow/timezones", ] dtype-time = ["polars-core/dtype-time", "polars-core/temporal", "polars-time/dtype-time"] dtype-struct = ["polars-core/dtype-struct"] diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs index 3d0316f60ab7e..49d74f0226edf 100644 --- a/crates/polars-io/src/cloud/adaptors.rs +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -2,135 +2,17 @@ //! This is used, for example, by the [parquet2] crate. //! //! [parquet2]: https://crates.io/crates/parquet2 -use std::io::{self}; -use std::pin::Pin; + use std::sync::Arc; -use std::task::Poll; -use bytes::Bytes; -use futures::executor::block_on; -use futures::future::BoxFuture; -use futures::{AsyncRead, AsyncSeek, Future, TryFutureExt}; use object_store::path::Path; use object_store::{MultipartId, ObjectStore}; -use polars_error::{to_compute_err, PolarsError, PolarsResult}; +use polars_error::{to_compute_err, PolarsResult}; use tokio::io::{AsyncWrite, AsyncWriteExt}; -use super::*; +use super::CloudOptions; use crate::pl_async::get_runtime; -type OptionalFuture = Option>>; - -/// Adaptor to translate from AsyncSeek and AsyncRead to the object_store get_range API. -pub struct CloudReader { - // The current position in the stream, it is set by seeking and updated by reading bytes. - pos: u64, - // The total size of the object is required when seeking from the end of the file. - length: Option, - // Hold an reference to the store in a thread safe way. - object_store: Arc, - // The path in the object_store of the current object being read. - path: Path, - // If a read is pending then `active` will point to its future. - active: OptionalFuture, -} - -impl CloudReader { - pub fn new(length: Option, object_store: Arc, path: Path) -> Self { - Self { - pos: 0, - length, - object_store, - path, - active: None, - } - } - - /// For each read request we create a new future. - async fn read_operation( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - length: usize, - ) -> std::task::Poll> { - let start = self.pos as usize; - - // If we already have a future just poll it. - if let Some(fut) = self.active.as_mut() { - return Future::poll(fut.as_mut(), cx); - } - - // Create the future. - let future = { - let path = self.path.clone(); - let object_store = self.object_store.clone(); - // Use an async move block to get our owned objects. - async move { - object_store - .get_range(&path, start..start + length) - .map_err(|e| { - std::io::Error::new( - std::io::ErrorKind::Other, - format!("object store error {e:?}"), - ) - }) - .await - } - }; - // Prepare for next read. - self.pos += length as u64; - - let mut future = Box::pin(future); - - // Need to poll it once to get the pump going. - let polled = Future::poll(future.as_mut(), cx); - - // Save for next time. - self.active = Some(future); - polled - } -} - -impl AsyncRead for CloudReader { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut [u8], - ) -> std::task::Poll> { - // Use block_on in order to get the future result in this thread and copy the data in the output buffer. - // With this approach we keep ownership of the buffer and we don't have to pass it to the future runtime. - match block_on(self.read_operation(cx, buf.len())) { - Poll::Ready(Ok(bytes)) => { - buf.copy_from_slice(bytes.as_ref()); - Poll::Ready(Ok(bytes.len())) - }, - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, - } - } -} - -impl AsyncSeek for CloudReader { - fn poll_seek( - mut self: Pin<&mut Self>, - _: &mut std::task::Context<'_>, - pos: io::SeekFrom, - ) -> std::task::Poll> { - match pos { - io::SeekFrom::Start(pos) => self.pos = pos, - io::SeekFrom::End(pos) => { - let length = self.length.ok_or::(io::Error::new( - std::io::ErrorKind::Other, - "Cannot seek from end of stream when length is unknown.", - ))?; - self.pos = (length as i64 + pos) as u64 - }, - io::SeekFrom::Current(pos) => self.pos = (self.pos as i64 + pos) as u64, - }; - self.active = None; - std::task::Poll::Ready(Ok(self.pos)) - } -} - /// Adaptor which wraps the asynchronous interface of [ObjectStore::put_multipart](https://docs.rs/object_store/latest/object_store/trait.ObjectStore.html#tymethod.put_multipart) /// exposing a synchronous interface which implements `std::io::Write`. /// @@ -157,16 +39,13 @@ impl CloudWriter { object_store: Arc, path: Path, ) -> PolarsResult { - let build_result = Self::build_writer(&object_store, &path).await; - match build_result { - Err(error) => Err(PolarsError::from(error)), - Ok((multipart_id, writer)) => Ok(CloudWriter { - object_store, - path, - multipart_id, - writer, - }), - } + let (multipart_id, writer) = Self::build_writer(&object_store, &path).await?; + Ok(CloudWriter { + object_store, + path, + multipart_id, + writer, + }) } /// Constructs a new CloudWriter from a path and an optional set of CloudOptions. @@ -226,9 +105,8 @@ impl Drop for CloudWriter { #[cfg(feature = "csv")] #[cfg(test)] mod tests { - use object_store::ObjectStore; use polars_core::df; - use polars_core::prelude::{DataFrame, NamedFrom}; + use polars_core::prelude::DataFrame; use super::*; diff --git a/crates/polars-io/src/cloud/glob.rs b/crates/polars-io/src/cloud/glob.rs index f59a236b09666..4d40f31c9d652 100644 --- a/crates/polars-io/src/cloud/glob.rs +++ b/crates/polars-io/src/cloud/glob.rs @@ -1,13 +1,13 @@ -use arrow::legacy::error::polars_bail; use futures::future::ready; use futures::{StreamExt, TryStreamExt}; use object_store::path::Path; use polars_core::error::to_compute_err; -use polars_core::prelude::{polars_ensure, polars_err, PolarsError, PolarsResult}; +use polars_core::prelude::{polars_ensure, polars_err}; +use polars_error::{PolarsError, PolarsResult}; use regex::Regex; use url::Url; -use super::*; +use super::CloudOptions; const DELIMITER: char = '/'; diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs index 4c46260de21ff..e8841a099df23 100644 --- a/crates/polars-io/src/cloud/mod.rs +++ b/crates/polars-io/src/cloud/mod.rs @@ -1,29 +1,24 @@ //! Interface with cloud storage through the object_store crate. #[cfg(feature = "cloud")] -use std::borrow::Cow; +mod adaptors; #[cfg(feature = "cloud")] -use std::sync::Arc; +pub use adaptors::*; #[cfg(feature = "cloud")] -use object_store::local::LocalFileSystem; +mod polars_object_store; #[cfg(feature = "cloud")] -use object_store::ObjectStore; -#[cfg(feature = "cloud")] -use polars_core::prelude::{polars_bail, PolarsError, PolarsResult}; +pub use polars_object_store::*; -#[cfg(feature = "cloud")] -mod adaptors; #[cfg(feature = "cloud")] mod glob; #[cfg(feature = "cloud")] -mod object_store_setup; -pub mod options; +pub use glob::*; #[cfg(feature = "cloud")] -pub use adaptors::*; -#[cfg(feature = "cloud")] -pub use glob::*; +mod object_store_setup; #[cfg(feature = "cloud")] pub use object_store_setup::*; + +pub mod options; pub use options::*; diff --git a/crates/polars-io/src/cloud/object_store_setup.rs b/crates/polars-io/src/cloud/object_store_setup.rs index 64860be341830..1e8e77c3d7cd4 100644 --- a/crates/polars-io/src/cloud/object_store_setup.rs +++ b/crates/polars-io/src/cloud/object_store_setup.rs @@ -1,17 +1,20 @@ +use std::sync::Arc; + +use object_store::local::LocalFileSystem; +use object_store::ObjectStore; use once_cell::sync::Lazy; -pub use options::*; -use polars_error::to_compute_err; +use polars_error::{polars_bail, to_compute_err, PolarsError, PolarsResult}; +use polars_utils::aliases::PlHashMap; use tokio::sync::RwLock; +use url::Url; -use super::*; - -type CacheKey = (String, Option); +use super::{parse_url, CloudLocation, CloudOptions, CloudType}; -/// A very simple cache that only stores a single object-store. -/// This greatly reduces the query times as multiple object stores (when reading many small files) +/// Object stores must be cached. Every object-store will do DNS lookups and /// get rate limited when querying the DNS (can take up to 5s). +/// Other reasons are connection pools that must be shared between as much as possible. #[allow(clippy::type_complexity)] -static OBJECT_STORE_CACHE: Lazy)>>> = +static OBJECT_STORE_CACHE: Lazy>>> = Lazy::new(Default::default); type BuildResult = PolarsResult<(CloudLocation, Arc)>; @@ -24,28 +27,39 @@ fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { ); } +/// Get the key of a url for object store registration. +/// The credential info will be removed +fn url_to_key(url: &Url) -> String { + format!( + "{}://{}", + url.scheme(), + &url[url::Position::BeforeHost..url::Position::AfterPort], + ) +} + /// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. -pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> BuildResult { +pub async fn build_object_store( + url: &str, + #[cfg_attr( + not(any(feature = "aws", feature = "gcp", feature = "azure")), + allow(unused_variables) + )] + options: Option<&CloudOptions>, +) -> BuildResult { let parsed = parse_url(url).map_err(to_compute_err)?; let cloud_location = CloudLocation::from_url(&parsed)?; - let options = options.cloned(); - let key = (url.to_string(), options); + let key = url_to_key(&parsed); { let cache = OBJECT_STORE_CACHE.read().await; - if let Some((stored_key, store)) = cache.as_ref() { - if stored_key == &key { - return Ok((cloud_location, store.clone())); - } + if let Some(store) = cache.get(&key) { + return Ok((cloud_location, store.clone())); } } - let options = key - .1 - .as_ref() - .map(Cow::Borrowed) - .unwrap_or_else(|| Cow::Owned(Default::default())); + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + let options = options.map(std::borrow::Cow::Borrowed).unwrap_or_default(); let cloud_type = CloudType::from_url(&parsed)?; let store = match cloud_type { @@ -88,7 +102,7 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu { let store = object_store::http::HttpBuilder::new() .with_url(url) - .with_client_options(get_client_options()) + .with_client_options(super::get_client_options()) .build()?; Ok::<_, PolarsError>(Arc::new(store) as Arc) } @@ -98,6 +112,10 @@ pub async fn build_object_store(url: &str, options: Option<&CloudOptions>) -> Bu }, }?; let mut cache = OBJECT_STORE_CACHE.write().await; - *cache = Some((key, store.clone())); + // Clear the cache if we surpass a certain amount of buckets. Don't expect that to happen. + if cache.len() > 512 { + cache.clear() + } + cache.insert(key, store.clone()); Ok((cloud_location, store)) } diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index 1df835d5d3b0b..475de166ae022 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -18,13 +18,10 @@ use object_store::gcp::GoogleCloudStorageBuilder; pub use object_store::gcp::GoogleConfigKey; #[cfg(any(feature = "aws", feature = "gcp", feature = "azure", feature = "http"))] use object_store::ClientOptions; -#[cfg(feature = "cloud")] -use object_store::ObjectStore; #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] use object_store::{BackoffConfig, RetryConfig}; #[cfg(feature = "aws")] use once_cell::sync::Lazy; -use polars_core::error::{PolarsError, PolarsResult}; use polars_error::*; #[cfg(feature = "aws")] use polars_utils::cache::FastFixedCache; @@ -172,9 +169,9 @@ pub(super) fn get_client_options() -> ClientOptions { // We set request timeout super high as the timeout isn't reset at ACK, // but starts from the moment we start downloading a body. // https://docs.rs/reqwest/latest/reqwest/struct.ClientBuilder.html#method.timeout - .with_timeout(std::time::Duration::from_secs(60 * 5)) - // Concurrency can increase connection latency, so also set high. - .with_connect_timeout(std::time::Duration::from_secs(30)) + .with_timeout_disabled() + // Concurrency can increase connection latency, so set to None, similar to default. + .with_connect_timeout_disabled() .with_allow_http(true) } @@ -227,9 +224,9 @@ impl CloudOptions { self } - /// Build the [`ObjectStore`] implementation for AWS. + /// Build the [`object_store::ObjectStore`] implementation for AWS. #[cfg(feature = "aws")] - pub async fn build_aws(&self, url: &str) -> PolarsResult { + pub async fn build_aws(&self, url: &str) -> PolarsResult { let options = self.aws.as_ref(); let mut builder = AmazonS3Builder::from_env().with_url(url); if let Some(options) = options { @@ -330,9 +327,9 @@ impl CloudOptions { self } - /// Build the [`ObjectStore`] implementation for Azure. + /// Build the [`object_store::ObjectStore`] implementation for Azure. #[cfg(feature = "azure")] - pub fn build_azure(&self, url: &str) -> PolarsResult { + pub fn build_azure(&self, url: &str) -> PolarsResult { let options = self.azure.as_ref(); let mut builder = MicrosoftAzureBuilder::from_env(); if let Some(options) = options { @@ -364,9 +361,9 @@ impl CloudOptions { self } - /// Build the [`ObjectStore`] implementation for GCP. + /// Build the [`object_store::ObjectStore`] implementation for GCP. #[cfg(feature = "gcp")] - pub fn build_gcp(&self, url: &str) -> PolarsResult { + pub fn build_gcp(&self, url: &str) -> PolarsResult { let options = self.gcp.as_ref(); let mut builder = GoogleCloudStorageBuilder::from_env(); if let Some(options) = options { diff --git a/crates/polars-io/src/cloud/polars_object_store.rs b/crates/polars-io/src/cloud/polars_object_store.rs new file mode 100644 index 0000000000000..e22b658520c10 --- /dev/null +++ b/crates/polars-io/src/cloud/polars_object_store.rs @@ -0,0 +1,61 @@ +use std::ops::Range; +use std::sync::Arc; + +use bytes::Bytes; +use object_store::path::Path; +use object_store::{ObjectMeta, ObjectStore}; +use polars_error::{to_compute_err, PolarsResult}; + +use crate::pl_async::{ + tune_with_concurrency_budget, with_concurrency_budget, MAX_BUDGET_PER_REQUEST, +}; + +/// Polars specific wrapper for `Arc` that limits the number of +/// concurrent requests for the entire application. +#[derive(Debug, Clone)] +pub struct PolarsObjectStore(Arc); + +impl PolarsObjectStore { + pub fn new(store: Arc) -> Self { + Self(store) + } + + pub async fn get(&self, path: &Path) -> PolarsResult { + tune_with_concurrency_budget(1, || async { + self.0 + .get(path) + .await + .map_err(to_compute_err)? + .bytes() + .await + .map_err(to_compute_err) + }) + .await + } + + pub async fn get_range(&self, path: &Path, range: Range) -> PolarsResult { + tune_with_concurrency_budget(1, || self.0.get_range(path, range)) + .await + .map_err(to_compute_err) + } + + pub async fn get_ranges( + &self, + path: &Path, + ranges: &[Range], + ) -> PolarsResult> { + tune_with_concurrency_budget( + (ranges.len() as u32).clamp(0, MAX_BUDGET_PER_REQUEST as u32), + || self.0.get_ranges(path, ranges), + ) + .await + .map_err(to_compute_err) + } + + /// Fetch the metadata of the parquet file, do not memoize it. + pub async fn head(&self, path: &Path) -> PolarsResult { + with_concurrency_budget(1, || self.0.head(path)) + .await + .map_err(to_compute_err) + } +} diff --git a/crates/polars-io/src/csv/buffer.rs b/crates/polars-io/src/csv/buffer.rs index 59852c6e47fa3..8e96b6e1b4c66 100644 --- a/crates/polars-io/src/csv/buffer.rs +++ b/crates/polars-io/src/csv/buffer.rs @@ -189,10 +189,8 @@ impl ParsedBuffer for Utf8Field { return Ok(()); } - let parse_result = validate_utf8(bytes); - // note that one branch writes without updating the length, so we must do that later. - let bytes = if needs_escaping { + let escaped_bytes = if needs_escaping { self.scratch.clear(); self.scratch.reserve(bytes.len()); polars_ensure!(bytes.len() > 1, ComputeError: "invalid csv file\n\nField `{}` is not properly escaped.", std::str::from_utf8(bytes).map_err(to_compute_err)?); @@ -209,20 +207,29 @@ impl ParsedBuffer for Utf8Field { bytes }; + // It is important that this happens after escaping, as invalid escaped string can produce + // invalid utf8. + let parse_result = validate_utf8(escaped_bytes); + match parse_result { true => { - let value = unsafe { std::str::from_utf8_unchecked(bytes) }; + let value = unsafe { std::str::from_utf8_unchecked(escaped_bytes) }; self.mutable.push_value(value) }, false => { if matches!(self.encoding, CsvEncoding::LossyUtf8) { // TODO! do this without allocating - let s = String::from_utf8_lossy(bytes); + let s = String::from_utf8_lossy(escaped_bytes); self.mutable.push_value(s.as_ref()) } else if ignore_errors { self.mutable.push_null() } else { - polars_bail!(ComputeError: "invalid utf-8 sequence"); + // If field before escaping is valid utf8, the escaping is incorrect. + if needs_escaping && validate_utf8(bytes) { + polars_bail!(ComputeError: "string field is not properly escaped"); + } else { + polars_bail!(ComputeError: "invalid utf-8 sequence"); + } } }, } diff --git a/crates/polars-io/src/csv/mod.rs b/crates/polars-io/src/csv/mod.rs index 4eaf0efbd73ce..fba65a0f719f3 100644 --- a/crates/polars-io/src/csv/mod.rs +++ b/crates/polars-io/src/csv/mod.rs @@ -54,6 +54,7 @@ use std::fs::File; use std::io::Write; use std::path::PathBuf; +pub use parser::count_rows; use polars_core::prelude::*; #[cfg(feature = "temporal")] use polars_time::prelude::*; diff --git a/crates/polars-io/src/csv/parser.rs b/crates/polars-io/src/csv/parser.rs index afa566d99cf2c..1be5616f3bfca 100644 --- a/crates/polars-io/src/csv/parser.rs +++ b/crates/polars-io/src/csv/parser.rs @@ -1,11 +1,78 @@ +use std::path::PathBuf; + use memchr::memchr2_iter; use num_traits::Pow; use polars_core::prelude::*; +use polars_core::POOL; +use polars_utils::index::Bounded; +use polars_utils::slice::GetSaferUnchecked; +use rayon::prelude::*; use super::buffer::*; use crate::csv::read::NullValuesCompiled; use crate::csv::splitfields::SplitFields; +use crate::csv::utils::get_file_chunks; use crate::csv::CommentPrefix; +use crate::utils::get_reader_bytes; + +/// Read the number of rows without parsing columns +/// useful for count(*) queries +pub fn count_rows( + path: &PathBuf, + separator: u8, + quote_char: Option, + comment_prefix: Option<&CommentPrefix>, + eol_char: u8, + has_header: bool, +) -> PolarsResult { + let mut reader = polars_utils::open_file(path)?; + let reader_bytes = get_reader_bytes(&mut reader)?; + const MIN_ROWS_PER_THREAD: usize = 1024; + let max_threads = POOL.current_num_threads(); + + // Determine if parallelism is beneficial and how many threads + let n_threads = get_line_stats( + &reader_bytes, + MIN_ROWS_PER_THREAD, + eol_char, + None, + separator, + quote_char, + ) + .map(|(mean, std)| { + let n_rows = (reader_bytes.len() as f32 / (mean - 0.01 * std)) as usize; + (n_rows / MIN_ROWS_PER_THREAD).clamp(1, max_threads) + }) + .unwrap_or(1); + + let file_chunks = get_file_chunks( + &reader_bytes, + n_threads, + None, + separator, + quote_char, + eol_char, + ); + + let iter = file_chunks.into_par_iter().map(|(start, stop)| { + let local_bytes = &reader_bytes[start..stop]; + let row_iterator = SplitLines::new(local_bytes, quote_char.unwrap_or(b'"'), eol_char); + if comment_prefix.is_some() { + Ok(row_iterator + .filter(|line| !line.is_empty() && !is_comment_line(line, comment_prefix)) + .count()) + } else { + Ok(row_iterator.count()) + } + }); + + let count_result: PolarsResult = POOL.install(|| iter.sum()); + + match count_result { + Ok(val) => Ok(val - (has_header as usize)), + Err(err) => Err(err), + } +} /// Skip the utf-8 Byte Order Mark. /// credits to csv-core @@ -183,7 +250,7 @@ pub(crate) fn get_line_stats( bytes: &[u8], n_lines: usize, eol_char: u8, - expected_fields: usize, + expected_fields: Option, separator: u8, quote_char: Option, ) -> Option<(f32, f32)> { @@ -199,7 +266,7 @@ pub(crate) fn get_line_stats( bytes_trunc = &bytes[offset..]; let pos = next_line_position( bytes_trunc, - Some(expected_fields), + expected_fields, separator, quote_char, eol_char, @@ -407,7 +474,9 @@ pub(super) fn parse_lines( match iter.next() { // end of line None => { - bytes = &bytes[std::cmp::min(read_sol, bytes.len())..]; + bytes = unsafe { + bytes.get_unchecked_release(std::cmp::min(read_sol, bytes.len())..) + }; break; }, Some((mut field, needs_escaping)) => { @@ -419,8 +488,11 @@ pub(super) fn parse_lines( if idx == next_projected as u32 { // the iterator is finished when it encounters a `\n` // this could be preceded by a '\r' - if field_len > 0 && field[field_len - 1] == b'\r' { - field = &field[..field_len - 1]; + unsafe { + if field_len > 0 && *field.get_unchecked_release(field_len - 1) == b'\r' + { + field = field.get_unchecked_release(..field_len - 1); + } } debug_assert!(processed_fields < buffers.len()); @@ -433,7 +505,7 @@ pub(super) fn parse_lines( // if we have null values argument, check if this field equal null value if let Some(null_values) = null_values { let field = if needs_escaping && !field.is_empty() { - &field[1..field.len() - 1] + unsafe { field.get_unchecked_release(1..field.len() - 1) } } else { field }; @@ -486,7 +558,7 @@ pub(super) fn parse_lines( Consider setting 'truncate_ragged_lines={}'."#, polars_error::constants::TRUE) } let bytes_rem = skip_this_line( - &bytes[read_sol - 1..], + unsafe { bytes.get_unchecked_release(read_sol - 1..) }, quote_char, eol_char, ); diff --git a/crates/polars-io/src/csv/read.rs b/crates/polars-io/src/csv/read.rs index 15d2e02c81d8c..6168fa620bb95 100644 --- a/crates/polars-io/src/csv/read.rs +++ b/crates/polars-io/src/csv/read.rs @@ -73,7 +73,8 @@ impl NullValuesCompiled { } } - /// Safety + /// # Safety + /// /// The caller must ensure that `index` is in bounds pub(super) unsafe fn is_null(&self, field: &[u8], index: usize) -> bool { use NullValuesCompiled::*; @@ -680,6 +681,8 @@ where #[cfg(feature = "temporal")] fn parse_dates(mut df: DataFrame, fixed_schema: &Schema) -> DataFrame { + use polars_core::POOL; + let cols = unsafe { std::mem::take(df.get_columns_mut()) } .into_par_iter() .map(|s| { @@ -699,8 +702,8 @@ fn parse_dates(mut df: DataFrame, fixed_schema: &Schema) -> DataFrame { }, _ => s, } - }) - .collect::>(); + }); + let cols = POOL.install(|| cols.collect::>()); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } diff --git a/crates/polars-io/src/csv/read_impl/mod.rs b/crates/polars-io/src/csv/read_impl/mod.rs index 9d667e0660244..05d88ed89036a 100644 --- a/crates/polars-io/src/csv/read_impl/mod.rs +++ b/crates/polars-io/src/csv/read_impl/mod.rs @@ -3,7 +3,6 @@ mod batched_read; use std::fmt; use std::ops::Deref; -use std::sync::Arc; pub use batched_mmap::*; pub use batched_read::*; @@ -73,7 +72,7 @@ pub(crate) fn cast_columns( } }) .collect::>>()?; - *df = DataFrame::new_no_checks(cols) + *df = unsafe { DataFrame::new_no_checks(cols) } } else { // cast to the original dtypes in the schema for fld in to_cast { @@ -346,7 +345,7 @@ impl<'a> CoreReader<'a> { bytes, self.sample_size, self.eol_char, - self.schema.len(), + Some(self.schema.len()), self.separator, self.quote_char, ) { @@ -426,7 +425,7 @@ impl<'a> CoreReader<'a> { let chunks = get_file_chunks( bytes, n_file_chunks, - self.schema.len(), + Some(self.schema.len()), self.separator, self.quote_char, self.eol_char, @@ -533,12 +532,11 @@ impl<'a> CoreReader<'a> { &self.schema, )?; - let mut local_df = DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - ); + let columns = buffers + .into_iter() + .map(|buf| buf.into_series()) + .collect::>()?; + let mut local_df = unsafe { DataFrame::new_no_checks(columns) }; let current_row_count = local_df.height() as IdxSize; if let Some(rc) = &self.row_index { local_df.with_row_index_mut(&rc.name, Some(rc.offset)); @@ -637,12 +635,11 @@ impl<'a> CoreReader<'a> { self.schema.as_ref(), )?; - DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - ) + let columns = buffers + .into_iter() + .map(|buf| buf.into_series()) + .collect::>()?; + unsafe { DataFrame::new_no_checks(columns) } }; cast_columns(&mut df, &self.to_cast, false, self.ignore_errors)?; @@ -732,10 +729,9 @@ fn read_chunk( )?; } - Ok(DataFrame::new_no_checks( - buffers - .into_iter() - .map(|buf| buf.into_series()) - .collect::>()?, - )) + let columns = buffers + .into_iter() + .map(|buf| buf.into_series()) + .collect::>()?; + Ok(unsafe { DataFrame::new_no_checks(columns) }) } diff --git a/crates/polars-io/src/csv/utils.rs b/crates/polars-io/src/csv/utils.rs index 1b1da32b74f19..64d98a6ff3ee7 100644 --- a/crates/polars-io/src/csv/utils.rs +++ b/crates/polars-io/src/csv/utils.rs @@ -4,7 +4,6 @@ use std::io::Read; use std::mem::MaybeUninit; use polars_core::config::verbose; -use polars_core::datatypes::PlHashSet; use polars_core::prelude::*; #[cfg(feature = "polars-time")] use polars_time::chunkedarray::string::infer as date_infer; @@ -25,7 +24,7 @@ use crate::utils::{BOOLEAN_RE, FLOAT_RE, INTEGER_RE}; pub(crate) fn get_file_chunks( bytes: &[u8], n_chunks: usize, - expected_fields: usize, + expected_fields: Option, separator: u8, quote_char: Option, eol_char: u8, @@ -43,7 +42,7 @@ pub(crate) fn get_file_chunks( let end_pos = match next_line_position( &bytes[search_pos..], - Some(expected_fields), + expected_fields, separator, quote_char, eol_char, @@ -684,7 +683,11 @@ mod test { let s = std::fs::read_to_string(path).unwrap(); let bytes = s.as_bytes(); // can be within -1 / +1 bounds. - assert!((get_file_chunks(bytes, 10, 4, b',', None, b'\n').len() as i32 - 10).abs() <= 1); - assert!((get_file_chunks(bytes, 8, 4, b',', None, b'\n').len() as i32 - 8).abs() <= 1); + assert!( + (get_file_chunks(bytes, 10, Some(4), b',', None, b'\n').len() as i32 - 10).abs() <= 1 + ); + assert!( + (get_file_chunks(bytes, 8, Some(4), b',', None, b'\n').len() as i32 - 8).abs() <= 1 + ); } } diff --git a/crates/polars-io/src/ipc/ipc_file.rs b/crates/polars-io/src/ipc/ipc_file.rs index f57d1132b7dd3..ec0cf8cf2ec0a 100644 --- a/crates/polars-io/src/ipc/ipc_file.rs +++ b/crates/polars-io/src/ipc/ipc_file.rs @@ -33,7 +33,6 @@ //! assert!(df.equals(&df_read)); //! ``` use std::io::{Read, Seek}; -use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; use arrow::io::ipc::read; @@ -91,16 +90,6 @@ fn check_mmap_err(err: PolarsError) -> PolarsResult<()> { } impl IpcReader { - #[doc(hidden)] - /// A very bad estimate of the number of rows - /// This estimation will be entirely off if the file is compressed. - /// And will be varying off depending on the data types. - pub fn _num_rows(&mut self) -> PolarsResult { - let metadata = self.get_metadata()?; - let n_cols = metadata.schema.fields.len(); - // this magic number 10 is computed from the yellow trip dataset - Ok((metadata.size as usize) / n_cols / 10) - } fn get_metadata(&mut self) -> PolarsResult<&read::FileMetadata> { if self.metadata.is_none() { let metadata = read::read_file_metadata(&mut self.reader)?; @@ -165,6 +154,13 @@ impl IpcReader { let rechunk = self.rechunk; let metadata = read::read_file_metadata(&mut self.reader)?; + // NOTE: For some code paths this already happened. See + // https://github.com/pola-rs/polars/pull/14984#discussion_r1520125000 + // where this was introduced. + if let Some(columns) = &self.columns { + self.projection = Some(columns_to_projection(columns, &metadata.schema)?); + } + let schema = if let Some(projection) = &self.projection { Arc::new(apply_projection(&metadata.schema, projection)) } else { diff --git a/crates/polars-io/src/ipc/ipc_reader_async.rs b/crates/polars-io/src/ipc/ipc_reader_async.rs new file mode 100644 index 0000000000000..1447881817ab9 --- /dev/null +++ b/crates/polars-io/src/ipc/ipc_reader_async.rs @@ -0,0 +1,205 @@ +use std::sync::Arc; + +use arrow::io::ipc::read::{get_row_count, FileMetadata, OutOfSpecKind}; +use object_store::path::Path; +use object_store::ObjectMeta; +use polars_core::datatypes::IDX_DTYPE; +use polars_core::frame::DataFrame; +use polars_core::schema::Schema; +use polars_error::{polars_bail, polars_err, to_compute_err, PolarsResult}; + +use crate::cloud::{build_object_store, CloudLocation, CloudOptions, PolarsObjectStore}; +use crate::predicates::PhysicalIoExpr; +use crate::prelude::{materialize_projection, IpcReader}; +use crate::RowIndex; + +/// An Arrow IPC reader implemented on top of PolarsObjectStore. +pub struct IpcReaderAsync { + store: PolarsObjectStore, + path: Path, +} + +#[derive(Default, Clone)] +pub struct IpcReadOptions { + // Names of the columns to include in the output. + projection: Option>, + + // The maximum number of rows to include in the output. + row_limit: Option, + + // Include a column with the row number under the provided name starting at the provided index. + row_index: Option, + + // Only include rows that pass this predicate. + predicate: Option>, +} + +impl IpcReadOptions { + pub fn with_projection(mut self, indices: impl Into>>) -> Self { + self.projection = indices.into(); + self + } + + pub fn with_row_limit(mut self, row_limit: impl Into>) -> Self { + self.row_limit = row_limit.into(); + self + } + + pub fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.row_index = row_index.into(); + self + } + + pub fn with_predicate(mut self, predicate: impl Into>>) -> Self { + self.predicate = predicate.into(); + self + } +} + +impl IpcReaderAsync { + pub async fn from_uri( + uri: &str, + cloud_options: Option<&CloudOptions>, + ) -> PolarsResult { + let ( + CloudLocation { + prefix, expansion, .. + }, + store, + ) = build_object_store(uri, cloud_options).await?; + + let path = { + // Any wildcards should already have been resolved here. Without this assertion they would + // be ignored. + debug_assert!(expansion.is_none(), "path should not contain wildcards"); + Path::from_url_path(prefix).map_err(to_compute_err)? + }; + + Ok(Self { + store: PolarsObjectStore::new(store), + path, + }) + } + + async fn object_metadata(&self) -> PolarsResult { + self.store.head(&self.path).await + } + + async fn file_size(&self) -> PolarsResult { + Ok(self.object_metadata().await?.size) + } + + pub async fn metadata(&self) -> PolarsResult { + let file_size = self.file_size().await?; + + // TODO: Do a larger request and hope that the entire footer is contained within it to save one round-trip. + let footer_metadata = + self.store + .get_range( + &self.path, + file_size.checked_sub(FOOTER_METADATA_SIZE).ok_or_else(|| { + to_compute_err("ipc file size is smaller than the minimum") + })?..file_size, + ) + .await?; + + let footer_size = deserialize_footer_metadata( + footer_metadata + .as_ref() + .try_into() + .map_err(to_compute_err)?, + )?; + + let footer = self + .store + .get_range( + &self.path, + file_size + .checked_sub(FOOTER_METADATA_SIZE + footer_size) + .ok_or_else(|| { + to_compute_err("invalid ipc footer metadata: footer size too large") + })?..file_size, + ) + .await?; + + arrow::io::ipc::read::deserialize_footer( + footer.as_ref(), + footer_size.try_into().map_err(to_compute_err)?, + ) + } + + pub async fn data( + &self, + metadata: Option<&FileMetadata>, + options: IpcReadOptions, + verbose: bool, + ) -> PolarsResult { + // TODO: Only download what is needed rather than the entire file by + // making use of the projection, row limit, predicate and such. + let bytes = self.store.get(&self.path).await?; + + let projection = match options.projection.as_deref() { + Some(projection) => { + fn prepare_schema(mut schema: Schema, row_index: Option<&RowIndex>) -> Schema { + if let Some(rc) = row_index { + let _ = schema.insert_at_index(0, rc.name.as_str().into(), IDX_DTYPE); + } + schema + } + + // Retrieve the metadata for the schema so we can map column names to indices. + let fetched_metadata; + let metadata = if let Some(metadata) = metadata { + metadata + } else { + // This branch is happens when _metadata is None, which can happen if we Deserialize the execution plan. + fetched_metadata = self.metadata().await?; + &fetched_metadata + }; + + let schema = prepare_schema((&metadata.schema).into(), options.row_index.as_ref()); + + let hive_partitions = None; + + materialize_projection( + Some(projection), + &schema, + hive_partitions, + options.row_index.is_some(), + ) + }, + None => None, + }; + + let reader = + as crate::SerReader<_>>::new(std::io::Cursor::new(bytes.as_ref())) + .with_row_index(options.row_index) + .with_n_rows(options.row_limit) + .with_projection(projection); + reader.finish_with_scan_ops(options.predicate, verbose) + } + + pub async fn count_rows(&self, _metadata: Option<&FileMetadata>) -> PolarsResult { + // TODO: Only download what is needed rather than the entire file by + // making use of the projection, row limit, predicate and such. + let bytes = self.store.get(&self.path).await?; + get_row_count(&mut std::io::Cursor::new(bytes.as_ref())) + } +} + +const FOOTER_METADATA_SIZE: usize = 10; + +// TODO: Move to polars-arrow and deduplicate parsing of footer metadata in +// sync and async readers. +fn deserialize_footer_metadata(bytes: [u8; FOOTER_METADATA_SIZE]) -> PolarsResult { + let footer_size: usize = + i32::from_le_bytes(bytes[0..4].try_into().unwrap_or_else(|_| unreachable!())) + .try_into() + .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?; + + if &bytes[4..] != b"ARROW1" { + polars_bail!(oos = OutOfSpecKind::InvalidFooter); + } + + Ok(footer_size) +} diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs index a36ae625ed41a..e748f670ad3b4 100644 --- a/crates/polars-io/src/ipc/ipc_stream.rs +++ b/crates/polars-io/src/ipc/ipc_stream.rs @@ -166,20 +166,14 @@ where self.projection = Some(prj); } - let sorted_projection = self.projection.clone().map(|mut proj| { - proj.sort_unstable(); - proj - }); - - let schema = if let Some(projection) = &sorted_projection { + let schema = if let Some(projection) = &self.projection { apply_projection(&metadata.schema, projection) } else { metadata.schema.clone() }; - let include_row_index = self.row_index.is_some(); let ipc_reader = - read::StreamReader::new(&mut self.reader, metadata.clone(), sorted_projection); + read::StreamReader::new(&mut self.reader, metadata.clone(), self.projection); finish_reader( ipc_reader, rechunk, @@ -188,35 +182,6 @@ where &schema, self.row_index, ) - .map(|df| fix_column_order(df, self.projection, include_row_index)) - } -} - -fn fix_column_order( - df: DataFrame, - projection: Option>, - include_row_index: bool, -) -> DataFrame { - if let Some(proj) = projection { - let offset = usize::from(include_row_index); - let mut args = (0..proj.len()).zip(proj).collect::>(); - // first el of tuple is argument index - // second el is the projection index - args.sort_unstable_by_key(|tpl| tpl.1); - let cols = df.get_columns(); - - let iter = args.iter().map(|tpl| cols[tpl.0 + offset].clone()); - let cols = if include_row_index { - let mut new_cols = vec![df.get_columns()[0].clone()]; - new_cols.extend(iter); - new_cols - } else { - iter.collect() - }; - - DataFrame::new_no_checks(cols) - } else { - df } } diff --git a/crates/polars-io/src/ipc/mod.rs b/crates/polars-io/src/ipc/mod.rs index 1366aa84324fe..813fc7f5df78d 100644 --- a/crates/polars-io/src/ipc/mod.rs +++ b/crates/polars-io/src/ipc/mod.rs @@ -16,3 +16,8 @@ pub use ipc_file::IpcReader; #[cfg(feature = "ipc_streaming")] pub use ipc_stream::*; pub use write::{BatchedWriter, IpcCompression, IpcWriter, IpcWriterOption}; + +#[cfg(feature = "cloud")] +mod ipc_reader_async; +#[cfg(feature = "cloud")] +pub use ipc_reader_async::*; diff --git a/crates/polars-io/src/json/infer.rs b/crates/polars-io/src/json/infer.rs index 578d9bc8fadf4..0019f98fb5f14 100644 --- a/crates/polars-io/src/json/infer.rs +++ b/crates/polars-io/src/json/infer.rs @@ -1,5 +1,3 @@ -use simd_json::value::BorrowedValue; - use super::*; pub(crate) fn json_values_to_supertype( diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index da7360985dc5a..2e1ce26924708 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -64,12 +64,10 @@ //! pub(crate) mod infer; -use std::convert::TryFrom; use std::io::Write; use std::num::NonZeroUsize; use std::ops::Deref; -use arrow::array::{ArrayRef, StructArray}; use arrow::legacy::conversion::chunk_to_struct; use polars_core::error::to_compute_err; use polars_core::prelude::*; diff --git a/crates/polars-io/src/options.rs b/crates/polars-io/src/options.rs index fe219e317140c..a4f23e9cc272c 100644 --- a/crates/polars-io/src/options.rs +++ b/crates/polars-io/src/options.rs @@ -1,4 +1,4 @@ -use arrow::legacy::prelude::IdxSize; +use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-io/src/parquet/async_impl.rs b/crates/polars-io/src/parquet/async_impl.rs index c13e5a3c46f1d..8d9f40755c6da 100644 --- a/crates/polars-io/src/parquet/async_impl.rs +++ b/crates/polars-io/src/parquet/async_impl.rs @@ -1,26 +1,23 @@ //! Read parquet files in parallel from the Object Store without a third party crate. use std::ops::Range; -use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; use bytes::Bytes; use object_store::path::Path as ObjectPath; -use object_store::ObjectStore; use polars_core::config::{get_rg_prefetch_size, verbose}; -use polars_core::datatypes::PlHashMap; -use polars_core::error::{to_compute_err, PolarsResult}; +use polars_core::error::to_compute_err; use polars_core::prelude::*; -use polars_parquet::read::{self as parquet2_read, RowGroupMetaData}; +use polars_parquet::read::RowGroupMetaData; use polars_parquet::write::FileMetaData; use smartstring::alias::String as SmartString; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::Mutex; -use super::cloud::{build_object_store, CloudLocation, CloudReader}; +use super::cloud::{build_object_store, CloudLocation}; use super::mmap::ColumnStore; -use crate::cloud::CloudOptions; +use crate::cloud::{CloudOptions, PolarsObjectStore}; use crate::parquet::read_impl::compute_row_group_range; -use crate::pl_async::{get_runtime, with_concurrency_budget, MAX_BUDGET_PER_REQUEST}; +use crate::pl_async::get_runtime; use crate::predicates::PhysicalIoExpr; use crate::prelude::predicates::read_this_row_group; @@ -29,9 +26,9 @@ type QueuePayload = (usize, DownloadedRowGroup); type QueueSend = Arc>>; pub struct ParquetObjectStore { - store: Arc, + store: PolarsObjectStore, path: ObjectPath, - length: Option, + length: Option, metadata: Option>, } @@ -41,64 +38,42 @@ impl ParquetObjectStore { options: Option<&CloudOptions>, metadata: Option>, ) -> PolarsResult { - let (CloudLocation { prefix, .. }, store) = build_object_store(uri, options).await?; + let ( + CloudLocation { + prefix, expansion, .. + }, + store, + ) = build_object_store(uri, options).await?; + + // Any wildcards should already have been resolved here. Without this assertion they would + // be ignored. + debug_assert!(expansion.is_none(), "path should not contain wildcards"); + let path = ObjectPath::from_url_path(prefix).map_err(to_compute_err)?; Ok(ParquetObjectStore { - store, - path: ObjectPath::from_url_path(prefix).map_err(to_compute_err)?, + store: PolarsObjectStore::new(store), + path, length: None, metadata, }) } async fn get_range(&self, start: usize, length: usize) -> PolarsResult { - with_concurrency_budget(1, || async { - self.store - .get_range(&self.path, start..start + length) - .await - .map_err(to_compute_err) - }) - .await + self.store + .get_range(&self.path, start..start + length) + .await } async fn get_ranges(&self, ranges: &[Range]) -> PolarsResult> { - // Object-store has a maximum of 10 concurrent. - with_concurrency_budget( - (ranges.len() as u32).clamp(0, MAX_BUDGET_PER_REQUEST as u32), - || async { - self.store - .get_ranges(&self.path, ranges) - .await - .map_err(to_compute_err) - }, - ) - .await + self.store.get_ranges(&self.path, ranges).await } /// Initialize the length property of the object, unless it has already been fetched. - async fn initialize_length(&mut self) -> PolarsResult<()> { - if self.length.is_some() { - return Ok(()); + async fn length(&mut self) -> PolarsResult { + if self.length.is_none() { + self.length = Some(self.store.head(&self.path).await?.size); } - with_concurrency_budget(1, || async { - self.length = Some( - self.store - .head(&self.path) - .await - .map_err(to_compute_err)? - .size as u64, - ); - Ok(()) - }) - .await - } - - pub async fn schema(&mut self) -> PolarsResult { - let metadata = self.get_metadata().await?; - - let arrow_schema = parquet2_read::infer_schema(metadata)?; - - Ok(Arc::new(arrow_schema)) + Ok(self.length.unwrap()) } /// Number of rows in the parquet file. @@ -109,18 +84,8 @@ impl ParquetObjectStore { /// Fetch the metadata of the parquet file, do not memoize it. async fn fetch_metadata(&mut self) -> PolarsResult { - self.initialize_length().await?; - let object_store = self.store.clone(); - let path = self.path.clone(); - let length = self.length; - let mut reader = CloudReader::new(length, object_store, path); - - with_concurrency_budget(1, || async { - parquet2_read::read_metadata_async(&mut reader) - .await - .map_err(to_compute_err) - }) - .await + let length = self.length().await?; + fetch_metadata(&self.store, &self.path, length).await } /// Fetch and memoize the metadata of the parquet file. @@ -132,6 +97,79 @@ impl ParquetObjectStore { } } +fn read_n(reader: &mut &[u8]) -> Option<[u8; N]> { + if N <= reader.len() { + let (head, tail) = reader.split_at(N); + *reader = tail; + Some(head.try_into().unwrap()) + } else { + None + } +} + +fn read_i32le(reader: &mut &[u8]) -> Option { + read_n(reader).map(i32::from_le_bytes) +} + +/// Asynchronously reads the files' metadata +pub async fn fetch_metadata( + store: &PolarsObjectStore, + path: &ObjectPath, + file_byte_length: usize, +) -> PolarsResult { + let footer_header_bytes = store + .get_range( + path, + file_byte_length + .checked_sub(polars_parquet::parquet::FOOTER_SIZE as usize) + .ok_or_else(|| { + polars_parquet::parquet::error::Error::OutOfSpec( + "not enough bytes to contain parquet footer".to_string(), + ) + })?..file_byte_length, + ) + .await?; + + let footer_byte_length: usize = { + let reader = &mut footer_header_bytes.as_ref(); + let footer_byte_size = read_i32le(reader).unwrap(); + let magic = read_n(reader).unwrap(); + debug_assert!(reader.is_empty()); + if magic != polars_parquet::parquet::PARQUET_MAGIC { + return Err(polars_parquet::parquet::error::Error::OutOfSpec( + "incorrect magic in parquet footer".to_string(), + ) + .into()); + } + footer_byte_size.try_into().map_err(|_| { + polars_parquet::parquet::error::Error::OutOfSpec( + "negative footer byte length".to_string(), + ) + })? + }; + + let footer_bytes = store + .get_range( + path, + file_byte_length + .checked_sub(polars_parquet::parquet::FOOTER_SIZE as usize + footer_byte_length) + .ok_or_else(|| { + polars_parquet::parquet::error::Error::OutOfSpec( + "not enough bytes to contain parquet footer".to_string(), + ) + })?..file_byte_length, + ) + .await?; + + Ok(polars_parquet::parquet::read::deserialize_metadata( + std::io::Cursor::new(footer_bytes.as_ref()), + // TODO: Describe why this makes sense. Taken from the previous + // implementation which said "a highly nested but sparse struct could + // result in many allocations". + footer_bytes.as_ref().len() * 2 + 1024, + )?) +} + /// Download rowgroups for the column whose indexes are given in `projection`. /// We concurrently download the columns for each field. async fn download_projection( diff --git a/crates/polars-io/src/parquet/mmap.rs b/crates/polars-io/src/parquet/mmap.rs index c7ece315d1f30..38013da5febef 100644 --- a/crates/polars-io/src/parquet/mmap.rs +++ b/crates/polars-io/src/parquet/mmap.rs @@ -1,8 +1,6 @@ use arrow::datatypes::Field; #[cfg(feature = "async")] use bytes::Bytes; -#[cfg(feature = "async")] -use polars_core::datatypes::PlHashMap; use polars_parquet::read::{ column_iter_to_arrays, get_field_columns, ArrayIter, BasicDecompressor, ColumnChunkMetaData, PageReader, diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs index cac3347fd464e..5691857a4ac4c 100644 --- a/crates/polars-io/src/parquet/read.rs +++ b/crates/polars-io/src/parquet/read.rs @@ -1,12 +1,10 @@ use std::io::{Read, Seek}; -use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; use polars_core::prelude::*; #[cfg(feature = "cloud")] use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_parquet::read; -use polars_parquet::write::FileMetaData; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -162,6 +160,7 @@ impl ParquetReader { chunk_size, self.use_statistics, self.hive_partition_columns, + self.parallel, ) } } @@ -233,6 +232,7 @@ pub struct ParquetAsyncReader { use_statistics: bool, hive_partition_columns: Option>, schema: Option, + parallel: ParallelStrategy, } #[cfg(feature = "cloud")] @@ -253,15 +253,21 @@ impl ParquetAsyncReader { use_statistics: true, hive_partition_columns: None, schema, + parallel: Default::default(), }) } pub async fn schema(&mut self) -> PolarsResult { - match &self.schema { - Some(schema) => Ok(schema.clone()), - None => self.reader.schema().await, - } + Ok(match self.schema.as_ref() { + Some(schema) => Arc::clone(schema), + None => { + let metadata = self.reader.get_metadata().await?; + let arrow_schema = polars_parquet::arrow::read::infer_schema(metadata)?; + Arc::new(arrow_schema) + }, + }) } + pub async fn num_rows(&mut self) -> PolarsResult { self.reader.num_rows().await } @@ -303,6 +309,11 @@ impl ParquetAsyncReader { self } + pub fn read_parallel(mut self, parallel: ParallelStrategy) -> Self { + self.parallel = parallel; + self + } + pub async fn batched(mut self, chunk_size: usize) -> PolarsResult { let metadata = self.reader.get_metadata().await?.clone(); let schema = match self.schema { @@ -330,6 +341,7 @@ impl ParquetAsyncReader { chunk_size, self.use_statistics, self.hive_partition_columns, + self.parallel, ) } diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index c3104c974f269..096fa9708c7a5 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -1,8 +1,6 @@ use std::borrow::Cow; use std::collections::VecDeque; -use std::convert::TryFrom; use std::ops::{Deref, Range}; -use std::sync::Arc; use arrow::array::new_empty_array; use arrow::datatypes::ArrowSchemaRef; @@ -111,7 +109,10 @@ pub(super) fn array_iter_to_series( /// Materializes hive partitions. /// We have a special num_rows arg, as df can be empty when a projection contains /// only hive partition columns. -/// Safety: num_rows equals the height of the df when the df height is non-zero. +/// +/// # Safety +/// +/// num_rows equals the height of the df when the df height is non-zero. pub(crate) fn materialize_hive_partitions( df: &mut DataFrame, hive_partition_columns: Option<&[Series]>, @@ -245,7 +246,7 @@ fn rg_to_dfs_optionally_par_over_columns( *remaining_rows -= projection_height; - let mut df = DataFrame::new_no_checks(columns); + let mut df = unsafe { DataFrame::new_no_checks(columns) }; if let Some(rc) = &row_index { df.with_row_index_mut(&rc.name, Some(*previous_row_count + rc.offset)); } @@ -332,7 +333,7 @@ fn rg_to_dfs_par_over_rg( }) .collect::>>()?; - let mut df = DataFrame::new_no_checks(columns); + let mut df = unsafe { DataFrame::new_no_checks(columns) }; if let Some(rc) = &row_index { df.with_row_index_mut(&rc.name, Some(row_count_start as IdxSize + rc.offset)); @@ -517,7 +518,7 @@ pub struct BatchedParquetReader { #[allow(dead_code)] row_group_fetcher: RowGroupFetcher, limit: usize, - projection: Vec, + projection: Arc<[usize]>, schema: ArrowSchemaRef, metadata: FileMetaDataRef, predicate: Option>, @@ -529,7 +530,7 @@ pub struct BatchedParquetReader { parallel: ParallelStrategy, chunk_size: usize, use_statistics: bool, - hive_partition_columns: Option>, + hive_partition_columns: Option>, /// Has returned at least one materialized frame. has_returned: bool, } @@ -547,16 +548,27 @@ impl BatchedParquetReader { chunk_size: usize, use_statistics: bool, hive_partition_columns: Option>, + mut parallel: ParallelStrategy, ) -> PolarsResult { let n_row_groups = metadata.row_groups.len(); - let projection = projection.unwrap_or_else(|| (0usize..schema.len()).collect::>()); + let projection = projection + .map(Arc::from) + .unwrap_or_else(|| (0usize..schema.len()).collect::>()); + + parallel = match parallel { + ParallelStrategy::Auto => { + if n_row_groups > projection.len() || n_row_groups > POOL.current_num_threads() { + ParallelStrategy::RowGroups + } else { + ParallelStrategy::Columns + } + }, + _ => parallel, + }; - let parallel = - if n_row_groups > projection.len() || n_row_groups > POOL.current_num_threads() { - ParallelStrategy::RowGroups - } else { - ParallelStrategy::Columns - }; + if let (ParallelStrategy::Columns, true) = (parallel, projection.len() == 1) { + parallel = ParallelStrategy::None; + } Ok(BatchedParquetReader { row_group_fetcher, @@ -573,7 +585,7 @@ impl BatchedParquetReader { parallel, chunk_size, use_statistics, - hive_partition_columns, + hive_partition_columns: hive_partition_columns.map(Arc::from), has_returned: false, }) } @@ -596,7 +608,13 @@ impl BatchedParquetReader { pub async fn next_batches(&mut self, n: usize) -> PolarsResult>> { if self.limit == 0 && self.has_returned { - return Ok(None); + return if self.chunks_fifo.is_empty() { + Ok(None) + } else { + // the range end point must not be greater than the length of the deque + let n_drainable = std::cmp::min(n, self.chunks_fifo.len()); + Ok(Some(self.chunks_fifo.drain(..n_drainable).collect())) + }; } let mut skipped_all_rgs = false; @@ -616,21 +634,66 @@ impl BatchedParquetReader { .fetch_row_groups(row_group_start..row_group_end) .await?; - let dfs = rg_to_dfs( - &store, - &mut self.rows_read, - row_group_start, - row_group_end, - &mut self.limit, - &self.metadata, - &self.schema, - self.predicate.as_deref(), - self.row_index.clone(), - self.parallel, - &self.projection, - self.use_statistics, - self.hive_partition_columns.as_deref(), - )?; + let dfs = match store { + ColumnStore::Local(_) => rg_to_dfs( + &store, + &mut self.rows_read, + row_group_start, + row_group_end, + &mut self.limit, + &self.metadata, + &self.schema, + self.predicate.as_deref(), + self.row_index.clone(), + self.parallel, + &self.projection, + self.use_statistics, + self.hive_partition_columns.as_deref(), + ), + #[cfg(feature = "async")] + ColumnStore::Fetched(b) => { + // This branch we spawn the decoding and decompression of the bytes on a rayon task. + // This will ensure we don't block the async thread. + + // Reconstruct as that makes it a 'static. + let store = ColumnStore::Fetched(b); + let (tx, rx) = tokio::sync::oneshot::channel(); + + // Make everything 'static. + let mut rows_read = self.rows_read; + let mut limit = self.limit; + let row_index = self.row_index.clone(); + let predicate = self.predicate.clone(); + let schema = self.schema.clone(); + let metadata = self.metadata.clone(); + let parallel = self.parallel; + let projection = self.projection.clone(); + let use_statistics = self.use_statistics; + let hive_partition_columns = self.hive_partition_columns.clone(); + POOL.spawn(move || { + let dfs = rg_to_dfs( + &store, + &mut rows_read, + row_group_start, + row_group_end, + &mut limit, + &metadata, + &schema, + predicate.as_deref(), + row_index, + parallel, + &projection, + use_statistics, + hive_partition_columns.as_deref(), + ); + tx.send((dfs, rows_read, limit)).unwrap(); + }); + let (dfs, rows_read, limit) = rx.await.unwrap(); + self.rows_read = rows_read; + self.limit = limit; + dfs + }, + }?; self.row_group_offset += n; @@ -665,8 +728,8 @@ impl BatchedParquetReader { if self.chunks_fifo.is_empty() { if skipped_all_rgs { Ok(Some(vec![materialize_empty_df( - Some(self.projection.as_slice()), - self.schema(), + Some(self.projection.as_ref()), + &self.schema, self.hive_partition_columns.as_deref(), self.row_index.as_ref(), )])) diff --git a/crates/polars-io/src/parquet/write.rs b/crates/polars-io/src/parquet/write.rs index 10694d8587811..1028149ff0a5c 100644 --- a/crates/polars-io/src/parquet/write.rs +++ b/crates/polars-io/src/parquet/write.rs @@ -1,19 +1,22 @@ +use std::borrow::Cow; +use std::collections::VecDeque; use std::io::Write; +use std::sync::Mutex; -use arrow::array::{Array, ArrayRef}; +use arrow::array::Array; use arrow::chunk::Chunk; -use arrow::datatypes::{ArrowDataType, PhysicalType}; +use arrow::datatypes::PhysicalType; use polars_core::prelude::*; -use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df}; +use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df_as_ref}; use polars_core::POOL; use polars_parquet::read::ParquetError; -use polars_parquet::write::{self, DynIter, DynStreamingIterator, Encoding, FileWriter, *}; +pub use polars_parquet::write::RowGroupIter; +use polars_parquet::write::{self, *}; use rayon::prelude::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use write::{ - BrotliLevel as BrotliLevelParquet, CompressionOptions, GzipLevel as GzipLevelParquet, - ZstdLevel as ZstdLevelParquet, + BrotliLevel as BrotliLevelParquet, GzipLevel as GzipLevelParquet, ZstdLevel as ZstdLevelParquet, }; #[derive(Debug, Eq, PartialEq, Hash, Clone, Copy)] @@ -175,7 +178,7 @@ where let parquet_schema = to_parquet_schema(&schema)?; let encodings = get_encodings(&schema); let options = self.materialize_options(); - let writer = FileWriter::try_new(self.writer, schema, options)?; + let writer = Mutex::new(FileWriter::try_new(self.writer, schema, options)?); Ok(BatchedWriter { writer, @@ -192,11 +195,26 @@ where df.align_chunks(); let n_splits = df.height() / self.row_group_size.unwrap_or(512 * 512); - if n_splits > 0 { - *df = accumulate_dataframes_vertical_unchecked(split_df(df, n_splits)?); - } + let chunked_df = if n_splits > 0 { + Cow::Owned(accumulate_dataframes_vertical_unchecked( + split_df_as_ref(df, n_splits, false)? + .into_iter() + .map(|mut df| { + // If the chunks are small enough, writing many small chunks + // leads to slow writing performance, so in that case we + // merge them. + let n_chunks = df.n_chunks(); + if n_chunks > 1 && (df.estimated_size() / n_chunks < 128 * 1024) { + df.as_single_chunk_par(); + } + df + }), + )) + } else { + Cow::Borrowed(df) + }; let mut batched = self.batched(&df.schema())?; - batched.write_batch(df)?; + batched.write_batch(&chunked_df)?; batched.finish() } } @@ -208,7 +226,7 @@ fn prepare_rg_iter<'a>( encodings: &'a [Vec], options: WriteOptions, parallel: bool, -) -> impl Iterator>> + 'a { +) -> impl Iterator>> + 'a { let rb_iter = df.iter_chunks(true); rb_iter.filter_map(move |batch| match batch.len() { 0 => None, @@ -232,9 +250,11 @@ fn get_encodings(schema: &ArrowSchema) -> Vec> { /// Declare encodings fn encoding_map(data_type: &ArrowDataType) -> Encoding { match data_type.to_physical_type() { - PhysicalType::Dictionary(_) | PhysicalType::LargeBinary | PhysicalType::LargeUtf8 => { - Encoding::RleDictionary - }, + PhysicalType::Dictionary(_) + | PhysicalType::LargeBinary + | PhysicalType::LargeUtf8 + | PhysicalType::Utf8View + | PhysicalType::BinaryView => Encoding::RleDictionary, PhysicalType::Primitive(dt) => { use arrow::types::PrimitiveType::*; match dt { @@ -248,7 +268,9 @@ fn encoding_map(data_type: &ArrowDataType) -> Encoding { } pub struct BatchedWriter { - writer: FileWriter, + // A mutex so that streaming engine can get concurrent read access to + // compress pages. + writer: Mutex>, parquet_schema: SchemaDescriptor, encodings: Vec>, options: WriteOptions, @@ -256,6 +278,26 @@ pub struct BatchedWriter { } impl BatchedWriter { + pub fn encode_and_compress<'a>( + &'a self, + df: &'a DataFrame, + ) -> impl Iterator>> + 'a { + let rb_iter = df.iter_chunks(true); + rb_iter.filter_map(move |batch| match batch.len() { + 0 => None, + _ => { + let row_group = create_eager_serializer( + batch, + self.parquet_schema.fields(), + self.encodings.as_ref(), + self.options, + ); + + Some(row_group) + }, + }) + } + /// Write a batch to the parquet writer. /// /// # Panics @@ -268,26 +310,45 @@ impl BatchedWriter { self.options, self.parallel, ); + // Lock before looping so that order is maintained under contention. + let mut writer = self.writer.lock().unwrap(); for group in row_group_iter { - self.writer.write(group?)?; + writer.write(group?)?; + } + Ok(()) + } + + pub fn get_writer(&self) -> &Mutex> { + &self.writer + } + + pub fn write_row_groups( + &self, + rgs: Vec>, + ) -> PolarsResult<()> { + // Lock before looping so that order is maintained. + let mut writer = self.writer.lock().unwrap(); + for group in rgs { + writer.write(group)?; } Ok(()) } /// Writes the footer of the parquet file. Returns the total size of the file. - pub fn finish(&mut self) -> PolarsResult { - let size = self.writer.end(None)?; + pub fn finish(&self) -> PolarsResult { + let mut writer = self.writer.lock().unwrap(); + let size = writer.end(None)?; Ok(size) } } -fn create_serializer<'a>( +fn create_serializer( batch: Chunk>, fields: &[ParquetType], encodings: &[Vec], options: WriteOptions, parallel: bool, -) -> PolarsResult> { +) -> PolarsResult> { let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec)| { let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap(); @@ -339,3 +400,74 @@ fn create_serializer<'a>( Ok(row_group) } + +struct CompressedPages { + pages: VecDeque>, + current: Option, +} + +impl CompressedPages { + fn new(pages: VecDeque>) -> Self { + Self { + pages, + current: None, + } + } +} + +impl FallibleStreamingIterator for CompressedPages { + type Item = CompressedPage; + type Error = PolarsError; + + fn advance(&mut self) -> Result<(), Self::Error> { + self.current = self.pages.pop_front().transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + self.current.as_ref() + } +} + +/// This serializer encodes and compresses all eagerly in memory. +/// Used for separating compute from IO. +fn create_eager_serializer( + batch: Chunk>, + fields: &[ParquetType], + encodings: &[Vec], + options: WriteOptions, +) -> PolarsResult> { + let func = move |((array, type_), encoding): ((&ArrayRef, &ParquetType), &Vec)| { + let encoded_columns = array_to_columns(array, type_.clone(), options, encoding).unwrap(); + + encoded_columns + .into_iter() + .map(|encoded_pages| { + let compressed_pages = encoded_pages + .into_iter() + .map(|page| { + let page = page?; + let page = compress(page, vec![], options.compression)?; + Ok(Ok(page)) + }) + .collect::>>()?; + + Ok(DynStreamingIterator::new(CompressedPages::new( + compressed_pages, + ))) + }) + .collect::>() + }; + + let columns = batch + .columns() + .iter() + .zip(fields) + .zip(encodings) + .flat_map(func) + .collect::>(); + + let row_group = DynIter::new(columns.into_iter()); + + Ok(row_group) +} diff --git a/crates/polars-io/src/pl_async.rs b/crates/polars-io/src/pl_async.rs index d5868d9cd4b95..be8b9146307e7 100644 --- a/crates/polars-io/src/pl_async.rs +++ b/crates/polars-io/src/pl_async.rs @@ -1,9 +1,12 @@ +use std::error::Error; use std::future::Future; use std::ops::Deref; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering}; use std::sync::RwLock; use std::thread::ThreadId; use once_cell::sync::Lazy; +use polars_core::config::verbose; use polars_core::POOL; use polars_utils::aliases::PlHashSet; use tokio::runtime::{Builder, Runtime}; @@ -12,17 +15,195 @@ use tokio::sync::Semaphore; static CONCURRENCY_BUDGET: std::sync::OnceLock<(Semaphore, u32)> = std::sync::OnceLock::new(); pub(super) const MAX_BUDGET_PER_REQUEST: usize = 10; +pub trait GetSize { + fn size(&self) -> u64; +} + +impl GetSize for bytes::Bytes { + fn size(&self) -> u64 { + self.len() as u64 + } +} + +impl GetSize for Vec { + fn size(&self) -> u64 { + self.iter().map(|v| v.size()).sum() + } +} + +impl GetSize for Result { + fn size(&self) -> u64 { + match self { + Ok(v) => v.size(), + Err(_) => 0, + } + } +} + +enum Optimization { + Step, + Accept, + Finished, +} + +struct SemaphoreTuner { + previous_download_speed: u64, + last_tune: std::time::Instant, + downloaded: AtomicU64, + download_time: AtomicU64, + opt_state: Optimization, + increments: u32, +} + +impl SemaphoreTuner { + fn new() -> Self { + Self { + previous_download_speed: 0, + last_tune: std::time::Instant::now(), + downloaded: AtomicU64::new(0), + download_time: AtomicU64::new(0), + opt_state: Optimization::Step, + increments: 0, + } + } + fn should_tune(&self) -> bool { + match self.opt_state { + Optimization::Finished => false, + _ => self.last_tune.elapsed().as_millis() > 350, + } + } + + fn add_stats(&self, downloaded_bytes: u64, download_time: u64) { + self.downloaded + .fetch_add(downloaded_bytes, Ordering::Relaxed); + self.download_time + .fetch_add(download_time, Ordering::Relaxed); + } + + fn increment(&mut self, semaphore: &Semaphore) { + semaphore.add_permits(1); + self.increments += 1; + } + + fn tune(&mut self, semaphore: &'static Semaphore) -> bool { + let download_speed = self.downloaded.fetch_add(0, Ordering::Relaxed) + / self.download_time.fetch_add(0, Ordering::Relaxed); + + let increased = download_speed > self.previous_download_speed; + self.previous_download_speed = download_speed; + match self.opt_state { + Optimization::Step => { + self.increment(semaphore); + self.opt_state = Optimization::Accept + }, + Optimization::Accept => { + // Accept the step + if increased { + // Set new step + self.increment(semaphore); + // Keep accept state to check next iteration + } + // Decline the step + else { + self.opt_state = Optimization::Finished; + FINISHED_TUNING.store(true, Ordering::Relaxed); + if verbose() { + eprintln!( + "concurrency tuner finished after adding {} steps", + self.increments + ) + } + // Finished. + return true; + } + }, + Optimization::Finished => {}, + } + self.last_tune = std::time::Instant::now(); + // Not finished. + false + } +} +static INCR: AtomicU8 = AtomicU8::new(0); +static FINISHED_TUNING: AtomicBool = AtomicBool::new(false); +static PERMIT_STORE: std::sync::OnceLock> = + std::sync::OnceLock::new(); + +fn get_semaphore() -> &'static (Semaphore, u32) { + CONCURRENCY_BUDGET.get_or_init(|| { + let permits = std::env::var("POLARS_CONCURRENCY_BUDGET") + .map(|s| { + let budget = s.parse::().expect("integer"); + FINISHED_TUNING.store(true, Ordering::Relaxed); + budget + }) + .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST)); + (Semaphore::new(permits), permits as u32) + }) +} + +pub async fn tune_with_concurrency_budget(requested_budget: u32, callable: F) -> Fut::Output +where + F: FnOnce() -> Fut, + Fut: Future, + Fut::Output: GetSize, +{ + let (semaphore, initial_budget) = get_semaphore(); + + // This would never finish otherwise. + assert!(requested_budget <= *initial_budget); + + // Keep permit around. + // On drop it is returned to the semaphore. + let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap(); + + let now = std::time::Instant::now(); + let res = callable().await; + + if FINISHED_TUNING.load(Ordering::Relaxed) || res.size() == 0 { + return res; + } + + let duration = now.elapsed().as_millis() as u64; + let permit_store = PERMIT_STORE.get_or_init(|| tokio::sync::RwLock::new(SemaphoreTuner::new())); + + let Ok(tuner) = permit_store.try_read() else { + return res; + }; + // Keep track of download speed + tuner.add_stats(res.size(), duration); + + // We only tune every n ms + if !tuner.should_tune() { + return res; + } + // Drop the read tuner before trying to acquire a writer + drop(tuner); + + // Reduce locking by letting only 1 in 5 tasks lock the tuner + if (INCR.fetch_add(1, Ordering::Relaxed) % 5) != 0 { + return res; + } + // Never lock as we will deadlock. This can run under rayon + let Ok(mut tuner) = permit_store.try_write() else { + return res; + }; + let finished = tuner.tune(semaphore); + if finished { + drop(_permit_acq); + // Undo the last step + let undo = semaphore.acquire().await.unwrap(); + std::mem::forget(undo) + } + res +} + pub async fn with_concurrency_budget(requested_budget: u32, callable: F) -> Fut::Output where F: FnOnce() -> Fut, Fut: Future, { - let (semaphore, initial_budget) = CONCURRENCY_BUDGET.get_or_init(|| { - let permits = std::env::var("POLARS_CONCURRENCY_BUDGET") - .map(|s| s.parse::().expect("integer")) - .unwrap_or_else(|_| std::cmp::max(POOL.current_num_threads(), MAX_BUDGET_PER_REQUEST)); - (Semaphore::new(permits), permits as u32) - }); + let (semaphore, initial_budget) = get_semaphore(); // This would never finish otherwise. assert!(requested_budget <= *initial_budget); @@ -30,6 +211,7 @@ where // Keep permit around. // On drop it is returned to the semaphore. let _permit_acq = semaphore.acquire_many(requested_budget).await.unwrap(); + callable().await } diff --git a/crates/polars-io/src/predicates.rs b/crates/polars-io/src/predicates.rs index 48aec098702a4..7c3dca6b654e0 100644 --- a/crates/polars-io/src/predicates.rs +++ b/crates/polars-io/src/predicates.rs @@ -20,7 +20,7 @@ pub trait StatsEvaluator { } #[cfg(feature = "parquet")] -pub(crate) fn apply_predicate( +pub fn apply_predicate( df: &mut DataFrame, predicate: Option<&dyn PhysicalIoExpr>, parallel: bool, diff --git a/crates/polars-io/src/utils.rs b/crates/polars-io/src/utils.rs index f6945156b990d..114fc979b6eb2 100644 --- a/crates/polars-io/src/utils.rs +++ b/crates/polars-io/src/utils.rs @@ -2,19 +2,10 @@ use std::io::Read; use std::path::{Path, PathBuf}; use once_cell::sync::Lazy; -#[cfg(any(feature = "csv", feature = "json"))] -use polars_core::frame::DataFrame; use polars_core::prelude::*; use regex::{Regex, RegexBuilder}; use crate::mmap::{MmapBytesReader, ReaderBytes}; -#[cfg(any( - feature = "ipc", - feature = "ipc_streaming", - feature = "parquet", - feature = "avro" -))] -use crate::ArrowSchema; pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( reader: &'a mut R, diff --git a/crates/polars-lazy/src/dsl/eval.rs b/crates/polars-lazy/src/dsl/eval.rs index 95dbf6b5f97db..2eae44388117e 100644 --- a/crates/polars-lazy/src/dsl/eval.rs +++ b/crates/polars-lazy/src/dsl/eval.rs @@ -82,7 +82,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { .map(|len| { let s = s.slice(0, len); if (len - s.null_count()) >= min_periods { - let df = DataFrame::new_no_checks(vec![s]); + let df = s.into_frame(); let out = phys_expr.evaluate(&df, &state)?; finish(out) } else { @@ -91,7 +91,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized { }) .collect::>>()? } else { - let mut df_container = DataFrame::new_no_checks(vec![]); + let mut df_container = DataFrame::empty(); (1..s.len() + 1) .map(|len| { let s = s.slice(0, len); diff --git a/crates/polars-lazy/src/dsl/list.rs b/crates/polars-lazy/src/dsl/list.rs index 9a4a97b5993bf..9d353a25c052a 100644 --- a/crates/polars-lazy/src/dsl/list.rs +++ b/crates/polars-lazy/src/dsl/list.rs @@ -61,7 +61,7 @@ fn run_per_sublist( .par_iter() .map(|opt_s| { opt_s.and_then(|s| { - let df = DataFrame::new_no_checks(vec![s]); + let df = s.into_frame(); let out = phys_expr.evaluate(&df, &state); match out { Ok(s) => Some(s), @@ -76,7 +76,7 @@ fn run_per_sublist( err = m_err.into_inner().unwrap(); ca } else { - let mut df_container = DataFrame::new_no_checks(vec![]); + let mut df_container = DataFrame::empty(); lst.into_iter() .map(|s| { @@ -124,7 +124,7 @@ fn run_on_group_by_engine( // Invariant in List means values physicals can be cast to inner dtype let values = unsafe { values.cast_unchecked(&inner_dtype).unwrap() }; - let df_context = DataFrame::new_no_checks(vec![values]); + let df_context = values.into_frame(); let phys_expr = prepare_expression_for_context("", expr, &inner_dtype, Context::Aggregation)?; let state = ExecutionState::new(); diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index f0844ffe7262f..7bd67b062052e 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -19,7 +19,6 @@ use std::path::PathBuf; use std::sync::Arc; pub use anonymous_scan::*; -use arrow::legacy::prelude::QuantileInterpolOptions; #[cfg(feature = "csv")] pub use csv::*; #[cfg(not(target_arch = "wasm32"))] @@ -31,20 +30,10 @@ pub use ipc::*; pub use ndjson::*; #[cfg(feature = "parquet")] pub use parquet::*; -use polars_core::frame::explode::MeltArgs; use polars_core::prelude::*; use polars_io::RowIndex; pub use polars_plan::frame::{AllowedOptimizations, OptState}; use polars_plan::global::FETCH_ROWS; -#[cfg(any( - feature = "ipc", - feature = "parquet", - feature = "csv", - feature = "json" -))] -use polars_plan::logical_plan::collect_fingerprints; -use polars_plan::logical_plan::optimize; -use polars_plan::utils::expr_output_name; use smartstring::alias::String as SmartString; use crate::fallible; @@ -151,6 +140,7 @@ impl LazyFrame { streaming: false, eager: false, fast_projection: false, + row_estimate: false, }) } @@ -198,12 +188,19 @@ impl LazyFrame { self } - /// Allow (partial) streaming engine. + /// Run nodes that are capably of doing so on the streaming engine. pub fn with_streaming(mut self, toggle: bool) -> Self { self.opt_state.streaming = toggle; self } + /// Try to estimate the number of rows so that joins can determine which side to keep in memory. + pub fn with_row_estimate(mut self, toggle: bool) -> Self { + self.opt_state.row_estimate = toggle; + self + } + + /// Run every node eagerly. This turns off multi-node optimizations. pub fn _with_eager(mut self, toggle: bool) -> Self { self.opt_state.eager = toggle; self @@ -613,7 +610,15 @@ impl LazyFrame { if streaming { #[cfg(feature = "streaming")] { - insert_streaming_nodes(lp_top, lp_arena, expr_arena, scratch, _fmt, true)?; + insert_streaming_nodes( + lp_top, + lp_arena, + expr_arena, + scratch, + _fmt, + true, + opt_state.row_estimate, + )?; } #[cfg(not(feature = "streaming"))] { @@ -1113,7 +1118,7 @@ impl LazyFrame { ) } - /// Creates the cartesian product from both frames, preserving the order of the left keys. + /// Creates the Cartesian product from both frames, preserving the order of the left keys. #[cfg(feature = "cross_join")] pub fn cross_join(self, other: LazyFrame) -> LazyFrame { self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross)) @@ -1411,7 +1416,11 @@ impl LazyFrame { /// - String columns will sum to None. pub fn sum(self) -> PolarsResult { self.stats_helper( - |dt| dt.is_numeric() || matches!(dt, DataType::Boolean | DataType::Duration(_)), + |dt| { + dt.is_numeric() + || dt.is_decimal() + || matches!(dt, DataType::Boolean | DataType::Duration(_)) + }, |name| col(name).sum(), ) } @@ -1933,7 +1942,7 @@ impl JoinBuilder { /// The passed expressions must be valid in both `LazyFrame`s in the join. pub fn on>(mut self, on: E) -> Self { let on = on.as_ref().to_vec(); - self.left_on = on.clone(); + self.left_on.clone_from(&on); self.right_on = on; self } diff --git a/crates/polars-lazy/src/frame/pivot.rs b/crates/polars-lazy/src/frame/pivot.rs index c9e0339593dbb..e7254ea0d908f 100644 --- a/crates/polars-lazy/src/frame/pivot.rs +++ b/crates/polars-lazy/src/frame/pivot.rs @@ -31,11 +31,11 @@ impl PhysicalAggExpr for PivotExpr { } } -pub fn pivot( +pub fn pivot( df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_expr: Option, // used as separator/delimiter in generated column names. @@ -43,10 +43,10 @@ pub fn pivot( ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { // make sure that the root column is replaced @@ -56,20 +56,20 @@ where }); polars_ops::pivot::pivot( df, - values, index, columns, + values, sort_columns, agg_expr, separator, ) } -pub fn pivot_stable( +pub fn pivot_stable( df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_expr: Option, // used as separator/delimiter in generated column names. @@ -77,10 +77,10 @@ pub fn pivot_stable( ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { // make sure that the root column is replaced @@ -90,9 +90,9 @@ where }); polars_ops::pivot::pivot_stable( df, - values, index, columns, + values, sort_columns, agg_expr, separator, diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs index a995d248b25c7..8d758ffc9f4ed 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_dynamic.rs @@ -1,8 +1,3 @@ -#[cfg(feature = "dynamic_group_by")] -use polars_core::frame::group_by::GroupBy; -#[cfg(feature = "dynamic_group_by")] -use polars_time::DynamicGroupOptions; - use super::*; #[cfg_attr(not(feature = "dynamic_group_by"), allow(dead_code))] diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs index f99aa2cd618ef..d10a09dffcbbd 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_partitioned.rs @@ -1,5 +1,5 @@ +use polars_core::series::IsSorted; use polars_core::utils::{accumulate_dataframes_vertical, split_df}; -use polars_core::POOL; use rayon::prelude::*; use super::*; @@ -149,13 +149,19 @@ fn estimate_unique_count(keys: &[Series], mut sample_size: usize) -> PolarsResul .iter() .map(|s| s.slice(offset, sample_size)) .collect::>(); - let df = DataFrame::new_no_checks(keys); + let df = unsafe { DataFrame::new_no_checks(keys) }; let names = df.get_column_names(); let gb = df.group_by(names).unwrap(); Ok(finish(gb.get_groups())) } } +// Lower this at debug builds so that we hit this in the test suite. +#[cfg(debug_assertions)] +const PARTITION_LIMIT: usize = 15; +#[cfg(not(debug_assertions))] +const PARTITION_LIMIT: usize = 1000; + // Checks if we should run normal or default aggregation // by sampling data. fn can_run_partitioned( @@ -164,7 +170,16 @@ fn can_run_partitioned( state: &ExecutionState, from_partitioned_ds: bool, ) -> PolarsResult { - if std::env::var("POLARS_NO_PARTITION").is_ok() { + if !keys + .iter() + .take(1) + .all(|s| matches!(s.is_sorted_flag(), IsSorted::Not)) + { + if state.verbose() { + eprintln!("FOUND SORTED KEY: running default HASH AGGREGATION") + } + Ok(false) + } else if std::env::var("POLARS_NO_PARTITION").is_ok() { if state.verbose() { eprintln!("POLARS_NO_PARTITION set: running default HASH AGGREGATION") } @@ -174,9 +189,9 @@ fn can_run_partitioned( eprintln!("POLARS_FORCE_PARTITION set: running partitioned HASH AGGREGATION") } Ok(true) - } else if original_df.height() < 1000 && !cfg!(test) { + } else if original_df.height() < PARTITION_LIMIT && !cfg!(test) { if state.verbose() { - eprintln!("DATAFRAME < 1000 rows: running default HASH AGGREGATION") + eprintln!("DATAFRAME < {PARTITION_LIMIT} rows: running default HASH AGGREGATION") } Ok(false) } else { @@ -207,10 +222,14 @@ fn can_run_partitioned( if from_partitioned_ds { let estimated_cardinality = unique_estimate as f32 / original_df.height() as f32; if estimated_cardinality < 0.4 { - eprintln!("PARTITIONED DS"); + if state.verbose() { + eprintln!("PARTITIONED DS"); + } Ok(true) } else { - eprintln!("PARTITIONED DS: estimated cardinality: {estimated_cardinality} exceeded the boundary: 0.4, running default HASH AGGREGATION"); + if state.verbose() { + eprintln!("PARTITIONED DS: estimated cardinality: {estimated_cardinality} exceeded the boundary: 0.4, running default HASH AGGREGATION"); + } Ok(false) } } else if unique_estimate > unique_count_boundary { @@ -257,6 +276,7 @@ impl PartitionGroupByExec { &mut vec![], false, false, + true, ) .unwrap(); @@ -297,7 +317,7 @@ impl PartitionGroupByExec { } #[cfg(feature = "streaming")] - if !self.maintain_order { + if !self.maintain_order && std::env::var("POLARS_NO_STREAMING_GROUPBY").is_err() { if let Some(out) = self.run_streaming(state, original_df.clone()) { return out; } diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs index cc2f1e2b677c0..74e21c4d11aaf 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs @@ -1,8 +1,3 @@ -#[cfg(feature = "dynamic_group_by")] -use polars_core::frame::group_by::GroupBy; -#[cfg(feature = "dynamic_group_by")] -use polars_time::RollingGroupOptions; - use super::*; #[cfg_attr(not(feature = "dynamic_group_by"), allow(dead_code))] diff --git a/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs b/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs index 464385681f043..14731d2a0adb7 100644 --- a/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs +++ b/crates/polars-lazy/src/physical_plan/executors/projection_utils.rs @@ -316,7 +316,7 @@ pub(super) fn check_expand_literals( .collect::>()? } - let df = DataFrame::new_no_checks(selected_columns); + let df = unsafe { DataFrame::new_no_checks(selected_columns) }; // a literal could be projected to a zero length dataframe. // This prevents a panic. diff --git a/crates/polars-lazy/src/physical_plan/executors/python_scan.rs b/crates/polars-lazy/src/physical_plan/executors/python_scan.rs index 85ed618f79f0a..1df5ad7861ef2 100644 --- a/crates/polars-lazy/src/physical_plan/executors/python_scan.rs +++ b/crates/polars-lazy/src/physical_plan/executors/python_scan.rs @@ -21,7 +21,7 @@ impl Executor for PythonScanExec { let n_rows = self.options.n_rows.take(); Python::with_gil(|py| { let pl = PyModule::import(py, "polars").unwrap(); - let utils = pl.getattr("utils").unwrap(); + let utils = pl.getattr("_utils").unwrap(); let callable = utils.getattr("_execute_from_rust").unwrap(); let python_scan_function = self.options.scan_fn.take().unwrap().0; diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs index bee325fc69d3f..56db2255f78be 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/csv.rs @@ -12,16 +12,14 @@ pub struct CsvExec { impl CsvExec { fn read(&mut self) -> PolarsResult { - let mut with_columns = mem::take(&mut self.file_options.with_columns); - let mut projected_len = 0; - with_columns.as_ref().map(|columns| { - projected_len = columns.len(); - columns - }); + let with_columns = self + .file_options + .with_columns + .take() + // Interpret selecting no columns as selecting all columns. + .filter(|columns| !columns.is_empty()) + .map(Arc::unwrap_or_clone); - if projected_len == 0 { - with_columns = None; - } let n_rows = _set_n_rows_for_scan(self.file_options.n_rows); let predicate = self.predicate.clone().map(phys_expr_to_io_expr); @@ -33,7 +31,7 @@ impl CsvExec { .with_ignore_errors(self.options.ignore_errors) .with_skip_rows(self.options.skip_rows) .with_n_rows(n_rows) - .with_columns(with_columns.map(|mut cols| std::mem::take(Arc::make_mut(&mut cols)))) + .with_columns(with_columns) .low_memory(self.options.low_memory) .with_null_values(std::mem::take(&mut self.options.null_values)) .with_predicate(predicate) diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index 08f37ab566aaf..123eeb9461e16 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -1,5 +1,10 @@ use std::path::PathBuf; +use polars_core::config::env_force_async; +#[cfg(feature = "cloud")] +use polars_io::cloud::CloudOptions; +use polars_io::is_cloud_url; + use super::*; pub struct IpcExec { @@ -8,10 +13,43 @@ pub struct IpcExec { pub(crate) predicate: Option>, pub(crate) options: IpcScanOptions, pub(crate) file_options: FileScanOptions, + #[cfg(feature = "cloud")] + pub(crate) cloud_options: Option, + pub(crate) metadata: Option, } impl IpcExec { fn read(&mut self, verbose: bool) -> PolarsResult { + let is_cloud = is_cloud_url(&self.path); + let force_async = env_force_async(); + + let mut out = if is_cloud || force_async { + #[cfg(not(feature = "cloud"))] + { + panic!("activate cloud feature") + } + + #[cfg(feature = "cloud")] + { + if !is_cloud && verbose { + eprintln!("ASYNC READING FORCED"); + } + + polars_io::pl_async::get_runtime() + .block_on_potential_spawn(self.read_async(verbose))? + } + } else { + self.read_sync(verbose)? + }; + + if self.file_options.rechunk { + out.as_single_chunk_par(); + } + + Ok(out) + } + + fn read_sync(&mut self, verbose: bool) -> PolarsResult { let file = std::fs::File::open(&self.path)?; let (projection, predicate) = prepare_scan_args( self.predicate.clone(), @@ -28,6 +66,26 @@ impl IpcExec { .memory_mapped(self.options.memmap) .finish_with_scan_ops(predicate, verbose) } + + #[cfg(feature = "cloud")] + async fn read_async(&mut self, verbose: bool) -> PolarsResult { + let predicate = self.predicate.clone().map(phys_expr_to_io_expr); + + let reader = + IpcReaderAsync::from_uri(self.path.to_str().unwrap(), self.cloud_options.as_ref()) + .await?; + reader + .data( + self.metadata.as_ref(), + IpcReadOptions::default() + .with_row_limit(self.file_options.n_rows) + .with_row_index(self.file_options.row_index.clone()) + .with_projection(self.file_options.with_columns.as_deref().cloned()) + .with_predicate(predicate), + verbose, + ) + .await + } } impl Executor for IpcExec { diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs index 9e8101052a4ac..40b66fef3ae64 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ndjson.rs @@ -1,5 +1,4 @@ use super::*; -use crate::prelude::{AnonymousScan, LazyJsonLineReader}; impl AnonymousScan for LazyJsonLineReader { fn as_any(&self) -> &dyn std::any::Any { diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs index 780eaea8fec4e..ef5d7f8da914b 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs @@ -1,10 +1,10 @@ use std::path::PathBuf; +use polars_core::config::env_force_async; #[cfg(feature = "cloud")] use polars_core::config::{get_file_prefetch_size, verbose}; use polars_core::utils::accumulate_dataframes_vertical; use polars_io::cloud::CloudOptions; -use polars_io::parquet::FileMetaData; use polars_io::{is_cloud_url, RowIndex}; use super::*; @@ -312,6 +312,17 @@ impl ParquetExec { } fn read(&mut self) -> PolarsResult { + // FIXME: The row index implementation is incorrect when a predicate is + // applied. This code mitigates that by applying the predicate after the + // collection of the entire dataframe if a row index is requested. This is + // inefficient. + let post_predicate = self + .file_options + .row_index + .as_ref() + .and_then(|_| self.predicate.take()) + .map(phys_expr_to_io_expr); + let is_cloud = match self.paths.first() { Some(p) => is_cloud_url(p.as_path()), None => { @@ -335,7 +346,7 @@ impl ParquetExec { )); }, }; - let force_async = std::env::var("POLARS_FORCE_ASYNC").as_deref().unwrap_or("") == "1"; + let force_async = env_force_async(); let out = if is_cloud || force_async { #[cfg(not(feature = "cloud"))] @@ -352,6 +363,9 @@ impl ParquetExec { }; let mut out = accumulate_dataframes_vertical(out)?; + + polars_io::predicates::apply_predicate(&mut out, post_predicate.as_deref(), true)?; + if self.file_options.rechunk { out.as_single_chunk_par(); } diff --git a/crates/polars-lazy/src/physical_plan/exotic.rs b/crates/polars-lazy/src/physical_plan/exotic.rs index bef1e42b5fc4d..138f1f566eaa2 100644 --- a/crates/polars-lazy/src/physical_plan/exotic.rs +++ b/crates/polars-lazy/src/physical_plan/exotic.rs @@ -31,7 +31,8 @@ pub(crate) fn prepare_expression_for_context( // create a dummy lazyframe and run a very simple optimization run so that // type coercion and simplify expression optimizations run. let column = Series::full_null(name, 0, dtype); - let lf = DataFrame::new_no_checks(vec![column]) + let lf = column + .into_frame() .lazy() .without_optimizations() .with_simplify_expr(true) diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index 9ac22d27399cc..5393d2a6a4a03 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -1,12 +1,9 @@ use std::borrow::Cow; -use std::sync::Arc; use arrow::array::*; use arrow::compute::concatenate::concatenate; -use arrow::legacy::prelude::QuantileInterpolOptions; use arrow::legacy::utils::CustomIterTools; use arrow::offset::Offsets; -use polars_core::frame::group_by::{GroupByMethod, GroupsProxy}; use polars_core::prelude::*; use polars_core::utils::NoNull; #[cfg(feature = "dtype-struct")] @@ -15,7 +12,6 @@ use polars_core::POOL; use polars_ops::prelude::nan_propagating_aggregate; use crate::physical_plan::state::ExecutionState; -use crate::physical_plan::PartitionedAggregation; use crate::prelude::AggState::{AggregatedList, AggregatedScalar}; use crate::prelude::*; diff --git a/crates/polars-lazy/src/physical_plan/expressions/alias.rs b/crates/polars-lazy/src/physical_plan/expressions/alias.rs index 44a84e96ddb56..c715083b01f46 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/alias.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/alias.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index 7ceb6a2c5ec41..1bf4654128949 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -1,15 +1,11 @@ use std::borrow::Cow; -use std::sync::Arc; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "parquet")] use polars_io::predicates::{BatchStats, StatsEvaluator}; #[cfg(feature = "is_between")] use polars_ops::prelude::ClosedInterval; -#[cfg(feature = "parquet")] -use polars_plan::dsl::FunctionExpr; use rayon::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/binary.rs b/crates/polars-lazy/src/physical_plan/expressions/binary.rs index c244c0f9bb002..f3b3d4e2f51be 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/binary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/binary.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; #[cfg(feature = "round_series")] @@ -397,7 +394,7 @@ mod stats { } } - let dummy = DataFrame::new_no_checks(vec![]); + let dummy = DataFrame::empty(); let state = ExecutionState::new(); let out = match (self.left.is_literal(), self.right.is_literal()) { diff --git a/crates/polars-lazy/src/physical_plan/expressions/cast.rs b/crates/polars-lazy/src/physical_plan/expressions/cast.rs index 962ac5086a86e..32ad204ba8675 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/cast.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/cast.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use crate::physical_plan::state::ExecutionState; diff --git a/crates/polars-lazy/src/physical_plan/expressions/column.rs b/crates/polars-lazy/src/physical_plan/expressions/column.rs index d4acf8a309bc7..eda761ab9d56c 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/column.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/column.rs @@ -1,7 +1,5 @@ use std::borrow::Cow; -use std::sync::Arc; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_plan::constants::CSE_REPLACED; diff --git a/crates/polars-lazy/src/physical_plan/expressions/filter.rs b/crates/polars-lazy/src/physical_plan/expressions/filter.rs index e6adb24953e81..b2cfe43e39975 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/filter.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/filter.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - use arrow::legacy::is_valid::IsValid; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; use polars_utils::idx_vec::IdxVec; diff --git a/crates/polars-lazy/src/physical_plan/expressions/literal.rs b/crates/polars-lazy/src/physical_plan/expressions/literal.rs index a0618b13751c4..cf33aa81c9a71 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/literal.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/literal.rs @@ -1,7 +1,6 @@ use std::borrow::Cow; use std::ops::Deref; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::utils::NoNull; use polars_plan::dsl::consts::LITERAL_NAME; diff --git a/crates/polars-lazy/src/physical_plan/expressions/mod.rs b/crates/polars-lazy/src/physical_plan/expressions/mod.rs index 2da66fd60374d..4642654a9fb6c 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/mod.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/mod.rs @@ -31,7 +31,6 @@ pub(crate) use column::*; pub(crate) use count::*; pub(crate) use filter::*; pub(crate) use literal::*; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_io::predicates::PhysicalIoExpr; #[cfg(feature = "dynamic_group_by")] diff --git a/crates/polars-lazy/src/physical_plan/expressions/slice.rs b/crates/polars-lazy/src/physical_plan/expressions/slice.rs index 13793e55ac346..3d0129675a96b 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/slice.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/slice.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::{GroupsProxy, IdxItem}; use polars_core::prelude::*; use polars_core::utils::{slice_offsets, CustomIterTools}; use polars_core::POOL; diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 77709c9d8a031..0df7d4b94ab97 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; use polars_ops::chunked_array::ListNameSpaceImpl; diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index b6ef95d46a004..d213af7631edc 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::{GroupsIndicator, GroupsProxy}; use polars_core::prelude::*; use polars_core::POOL; use polars_utils::idx_vec::IdxVec; diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index b6b20ff5830e8..9408635de3329 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -1,8 +1,5 @@ -use std::sync::Arc; - use arrow::legacy::utils::CustomIterTools; use polars_core::chunked_array::builder::get_list_builder; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::utils::NoNull; use polars_ops::prelude::{convert_to_unsigned_index, is_positive_idx_uncertain}; diff --git a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs index aed1b74cf710e..d52cb4eb8d61b 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/ternary.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/ternary.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::POOL; diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index 95acb8ee06398..8cefbb4cde34b 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -1,16 +1,12 @@ use std::fmt::Write; -use std::sync::Arc; use arrow::array::PrimitiveArray; use polars_core::export::arrow::bitmap::Bitmap; -use polars_core::frame::group_by::{GroupBy, GroupsProxy}; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::_split_offsets; use polars_core::{downcast_as_macro_arg_physical, POOL}; -use polars_ops::frame::join::{ - default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds, JoinValidation, -}; +use polars_ops::frame::join::{default_join_ids, private_left_join_multiple_keys, ChunkJoinOptIds}; use polars_ops::frame::SeriesJoin; use polars_utils::format_smartstring; use polars_utils::sort::perfect_sort; @@ -18,7 +14,6 @@ use polars_utils::sync::SyncPtr; use rayon::prelude::*; use super::*; -use crate::physical_plan::state::ExecutionState; use crate::prelude::*; pub struct WindowExpr { @@ -567,8 +562,8 @@ impl PhysicalExpr for WindowExpr { .unwrap() .1 } else { - let df_right = DataFrame::new_no_checks(keys); - let df_left = DataFrame::new_no_checks(group_by_columns); + let df_right = unsafe { DataFrame::new_no_checks(keys) }; + let df_left = unsafe { DataFrame::new_no_checks(group_by_columns) }; private_left_join_multiple_keys( &df_left, &df_right, None, None, true, ) @@ -635,7 +630,7 @@ fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Series) -> Se match join_opt_ids { Either::Left(ids) => unsafe { - out_column.take_unchecked(&ids.iter().copied().collect_ca("")) + IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx)) }, Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids) }, } diff --git a/crates/polars-lazy/src/physical_plan/file_cache.rs b/crates/polars-lazy/src/physical_plan/file_cache.rs index 5ea1074d95b0e..ee7c8e8ddffa7 100644 --- a/crates/polars-lazy/src/physical_plan/file_cache.rs +++ b/crates/polars-lazy/src/physical_plan/file_cache.rs @@ -1,13 +1,6 @@ use std::sync::Mutex; use polars_core::prelude::*; -#[cfg(any( - feature = "parquet", - feature = "csv", - feature = "ipc", - feature = "json" -))] -use polars_plan::logical_plan::FileFingerPrint; use crate::prelude::*; diff --git a/crates/polars-lazy/src/physical_plan/node_timer.rs b/crates/polars-lazy/src/physical_plan/node_timer.rs index 8be6861dda398..4926f7df8c592 100644 --- a/crates/polars-lazy/src/physical_plan/node_timer.rs +++ b/crates/polars-lazy/src/physical_plan/node_timer.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; use std::time::Instant; use polars_core::prelude::*; @@ -57,10 +57,8 @@ impl NodeTimer { let mut end = end.into_inner(); end.rename("end"); - DataFrame::new_no_checks(vec![nodes_s, start.into_series(), end.into_series()]).sort( - vec!["start"], - vec![false], - false, - ) + let columns = vec![nodes_s, start.into_series(), end.into_series()]; + let df = unsafe { DataFrame::new_no_checks(columns) }; + df.sort(vec!["start"], vec![false], false) } } diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 0489bf40c257a..26e5c920ca6c5 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -1,4 +1,3 @@ -use polars_core::frame::group_by::GroupByMethod; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::_split_offsets; diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index ea47cf3308dcd..dc6cbc81b255e 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -34,8 +34,15 @@ fn partitionable_gb( if partitionable { for agg in aggs { - let aexpr = expr_arena.get(*agg); - let depth = (expr_arena).iter(*agg).count(); + let mut agg = *agg; + let mut aexpr = expr_arena.get(agg); + // It should end with an aggregation + if let AExpr::Alias(input, _) = aexpr { + agg = *input; + aexpr = expr_arena.get(agg); + } + + let depth = (expr_arena).iter(agg).count(); // These single expressions are partitionable if matches!(aexpr, AExpr::Len) { @@ -48,29 +55,13 @@ fn partitionable_gb( break; } - // it should end with an aggregation - if let AExpr::Alias(input, _) = aexpr { - // col().agg().alias() is allowed: count of 3 - // col().alias() is not allowed: count of 2 - // count().alias() is allowed: count of 2 - if depth <= 2 { - match expr_arena.get(*input) { - AExpr::Len => {}, - _ => { - partitionable = false; - break; - }, - } - } - } - let has_aggregation = |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_))); // check if the aggregation type is partitionable // only simple aggregation like col().sum // that can be divided in to the aggregation of their partitions are allowed - if !((expr_arena).iter(*agg).all(|(_, ae)| { + if !((expr_arena).iter(agg).all(|(_, ae)| { use AExpr::*; match ae { // struct is needed to keep both states @@ -78,7 +69,7 @@ fn partitionable_gb( Agg(AAggExpr::Mean(_)) => { // only numeric means for now. // logical types seem to break because of casts to float. - matches!(expr_arena.get(*agg).get_type(_input_schema, Context::Default, expr_arena).map(|dt| { + matches!(expr_arena.get(agg).get_type(_input_schema, Context::Default, expr_arena).map(|dt| { dt.is_numeric()}), Ok(true)) }, // only allowed expressions @@ -120,7 +111,7 @@ fn partitionable_gb( #[cfg(feature = "object")] { - for name in aexpr_to_leaf_names(*agg, expr_arena) { + for name in aexpr_to_leaf_names(agg, expr_arena) { let dtype = _input_schema.get(&name).unwrap(); if let DataType::Object(_, _) = dtype { @@ -241,7 +232,12 @@ pub fn create_physical_plan( })) }, #[cfg(feature = "ipc")] - FileScan::Ipc { options } => { + FileScan::Ipc { + options, + #[cfg(feature = "cloud")] + cloud_options, + metadata, + } => { assert_eq!(paths.len(), 1); let path = paths[0].clone(); Ok(Box::new(executors::IpcExec { @@ -250,6 +246,9 @@ pub fn create_physical_plan( predicate, options, file_options, + #[cfg(feature = "cloud")] + cloud_options, + metadata, })) }, #[cfg(feature = "parquet")] diff --git a/crates/polars-lazy/src/physical_plan/state.rs b/crates/polars-lazy/src/physical_plan/state.rs index 5112e39eb2db8..c946399d5017e 100644 --- a/crates/polars-lazy/src/physical_plan/state.rs +++ b/crates/polars-lazy/src/physical_plan/state.rs @@ -5,7 +5,6 @@ use std::sync::{Mutex, RwLock}; use bitflags::bitflags; use once_cell::sync::OnceCell; use polars_core::config::verbose; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_ops::prelude::ChunkJoinOptIds; #[cfg(any( diff --git a/crates/polars-lazy/src/physical_plan/streaming/checks.rs b/crates/polars-lazy/src/physical_plan/streaming/checks.rs index fc8b8f2e1ad2b..7dd0adc328518 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/checks.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/checks.rs @@ -79,7 +79,7 @@ pub(super) fn streamable_join(args: &JoinArgs) -> bool { let supported = match args.how { #[cfg(feature = "cross_join")] JoinType::Cross => true, - JoinType::Inner | JoinType::Left => true, + JoinType::Inner | JoinType::Left | JoinType::Outer { .. } => true, _ => false, }; supported && !args.validation.needs_checks() diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs index 506760c9744d7..6aac52bb97c2c 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs @@ -1,16 +1,16 @@ use std::any::Any; use std::cell::RefCell; use std::rc::Rc; -use std::sync::Arc; use polars_core::config::verbose; use polars_core::prelude::*; use polars_io::predicates::{PhysicalIoExpr, StatsEvaluator}; use polars_pipe::expressions::PhysicalPipedExpr; use polars_pipe::operators::chunks::DataChunk; -use polars_pipe::pipeline::{create_pipeline, get_dummy_operator, get_operator, PipeLine}; +use polars_pipe::pipeline::{ + create_pipeline, execute_pipeline, get_dummy_operator, get_operator, CallBacks, PipeLine, +}; use polars_pipe::SExecutionContext; -use polars_utils::IdxSize; use crate::physical_plan::planner::{create_physical_expr, ExpressionConversionState}; use crate::physical_plan::state::ExecutionState; @@ -107,32 +107,42 @@ pub(super) fn construct( use ALogicalPlan::*; let mut pipelines = Vec::with_capacity(tree.len()); + let mut callbacks = CallBacks::new(); let is_verbose = verbose(); - // first traverse the branches and nodes to determine how often a sink is - // shared - // this shared count will be used in the pipeline to determine + // First traverse the branches and nodes to determine how often a sink is + // shared. + // This shared count will be used in the pipeline to determine // when the sink can be finalized. let mut sink_share_count = PlHashMap::new(); let n_branches = tree.len(); if n_branches > 1 { for branch in &tree { - for sink in branch.iter_sinks() { - let count = sink_share_count - .entry(sink.0) - .or_insert(Rc::new(RefCell::new(0u32))); - *count.borrow_mut() += 1; + for op in branch.operators_sinks.iter() { + match op { + PipelineNode::Sink(sink) => { + let count = sink_share_count + .entry(sink.0) + .or_insert(Rc::new(RefCell::new(0u32))); + *count.borrow_mut() += 1; + }, + PipelineNode::RhsJoin(node) => { + let _ = callbacks.insert(*node, get_dummy_operator()); + }, + _ => {}, + } } } } - // shared sinks are stored in a cache, so that they share info + // Shared sinks are stored in a cache, so that they share state. + // If the shared sink is already in cache, that one is used. let mut sink_cache = PlHashMap::new(); let mut final_sink = None; for branch in tree { - // the file sink is always to the top of the tree + // The file sink is always to the top of the tree // not every branch has a final sink. For instance rhs join branches if let Some(node) = branch.get_final_sink() { if matches!(lp_arena.get(node), ALogicalPlan::Sink { .. }) { @@ -174,32 +184,26 @@ pub(super) fn construct( PipelineNode::RhsJoin(node) => { operator_nodes.push(node); jit_insert_slice(node, lp_arena, &mut sink_nodes, operator_offset); - let op = get_dummy_operator(); - operators.push(op) + let op = callbacks.get(&node).unwrap().clone(); + operators.push(Box::new(op)) }, } } - let execution_id = branch.execution_id; let pipeline = create_pipeline( &branch.sources, operators, - operator_nodes, sink_nodes, lp_arena, expr_arena, to_physical_piped_expr, is_verbose, &mut sink_cache, + &mut callbacks, )?; - pipelines.push((execution_id, pipeline)); + pipelines.push(pipeline); } - // We sort to ensure we execute in the stack traversal order. - // this is important to make unions and joins work as expected - // also pipelines are not ready to receive inputs otherwise - pipelines.sort_by(|a, b| a.0.cmp(&b.0)); - let Some(final_sink) = final_sink else { return Ok(None); }; @@ -224,18 +228,12 @@ pub(super) fn construct( None }; - let Some((_, mut most_left)) = pipelines.pop() else { - unreachable!() - }; - while let Some((_, rhs)) = pipelines.pop() { - most_left = most_left.with_other_branch(rhs) - } - // replace the part of the logical plan with a `MapFunction` that will execute the pipeline. + // Replace the part of the logical plan with a `MapFunction` that will execute the pipeline. let schema = lp_arena .get(insertion_location) .schema(lp_arena) .into_owned(); - let pipeline_node = get_pipeline_node(lp_arena, most_left, schema, original_lp); + let pipeline_node = get_pipeline_node(lp_arena, pipelines, schema, original_lp); lp_arena.replace(insertion_location, pipeline_node); Ok(Some(final_sink)) @@ -253,7 +251,7 @@ impl SExecutionContext for ExecutionState { fn get_pipeline_node( lp_arena: &mut Arena, - mut pipeline: PipeLine, + mut pipelines: Vec, schema: SchemaRef, original_lp: Option, ) -> ALogicalPlan { @@ -272,11 +270,12 @@ fn get_pipeline_node( function: Arc::new(move |_df: DataFrame| { let mut state = ExecutionState::new(); if state.verbose() { - eprintln!("RUN STREAMING PIPELINE") + eprintln!("RUN STREAMING PIPELINE"); + eprintln!("{:?}", &pipelines) } state.set_in_streaming_engine(); let state = Box::new(state) as Box; - pipeline.execute(state) + execute_pipeline(state, std::mem::take(&mut pipelines)) }), schema, original: original_lp.map(Arc::new), diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index f5bb0f50e3a47..739bcb4f8545a 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -1,4 +1,3 @@ -use polars_core::error::PolarsResult; use polars_core::prelude::*; use polars_pipe::pipeline::swap_join_order; use polars_plan::prelude::*; @@ -106,12 +105,16 @@ pub(crate) fn insert_streaming_nodes( // whether the full plan needs to be translated // to streaming allow_partial: bool, + row_estimate: bool, ) -> PolarsResult { scratch.clear(); - // this is needed to determine which side of the joins should be - // traversed first - set_estimated_row_counts(root, lp_arena, expr_arena, 0, scratch); + // This is needed to determine which side of the joins should be + // traversed first. As we want to keep the smallest table in the build phase as that keeps most + // data in memory. + if row_estimate { + set_estimated_row_counts(root, lp_arena, expr_arena, 0, scratch); + } scratch.clear(); @@ -284,15 +287,15 @@ pub(crate) fn insert_streaming_nodes( }; let mut state_left = state.split(); - // rhs is second, so that is first on the stack + // Rhs is second, so that is first on the stack. let mut state_right = state; state_right.join_count = 0; state_right .operators_sinks .push(PipelineNode::RhsJoin(root)); - // we want to traverse lhs last, so push it first on the stack - // rhs is a new pipeline + // We want to traverse lhs last, so push it first on the stack + // rhs is a new pipeline. state_left.operators_sinks.push(PipelineNode::Sink(root)); stack.push(StackFrame::new(input_left, state_left, current_idx)); stack.push(StackFrame::new(input_right, state_right, current_idx)); @@ -328,7 +331,7 @@ pub(crate) fn insert_streaming_nodes( }; for (i, input) in inputs.iter().enumerate() { let mut state = if i == 0 { - // note the clone! + // Note the clone! let mut state = state.clone(); state.join_count += inputs.len() as u32 - 1; state diff --git a/crates/polars-lazy/src/physical_plan/streaming/tree.rs b/crates/polars-lazy/src/physical_plan/streaming/tree.rs index d948ab3664054..0405884207a5d 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/tree.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/tree.rs @@ -1,11 +1,7 @@ use std::collections::BTreeSet; use std::fmt::Debug; -#[cfg(debug_assertions)] use polars_plan::prelude::*; -#[cfg(debug_assertions)] -use polars_utils::arena::Arena; -use polars_utils::arena::Node; #[derive(Copy, Clone, Debug)] pub(super) enum PipelineNode { @@ -58,10 +54,6 @@ impl Branch { // so the first sink is the final one. self.operators_sinks.iter().find_map(sink_node) } - pub(super) fn iter_sinks(&self) -> impl Iterator + '_ { - self.operators_sinks.iter().flat_map(sink_node) - } - pub(super) fn split(&self) -> Self { Self { execution_id: self.execution_id, diff --git a/crates/polars-lazy/src/scan/csv.rs b/crates/polars-lazy/src/scan/csv.rs index 99c3495605ceb..3f6c3665f35ad 100644 --- a/crates/polars-lazy/src/scan/csv.rs +++ b/crates/polars-lazy/src/scan/csv.rs @@ -6,7 +6,6 @@ use polars_io::csv::{CommentPrefix, CsvEncoding, NullValues}; use polars_io::utils::get_reader_bytes; use polars_io::RowIndex; -use crate::frame::LazyFileListReader; use crate::prelude::*; #[derive(Clone)] @@ -196,7 +195,7 @@ impl<'a> LazyCsvReader<'a> { self } - /// Reduce memory usage in expensive of performance + /// Reduce memory usage at the expense of performance #[must_use] pub fn low_memory(mut self, toggle: bool) -> Self { self.low_memory = toggle; @@ -332,6 +331,16 @@ impl LazyFileListReader for LazyCsvReader<'_> { self } + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.row_index = row_index.into(); + self + } + fn rechunk(&self) -> bool { self.rechunk } diff --git a/crates/polars-lazy/src/scan/file_list_reader.rs b/crates/polars-lazy/src/scan/file_list_reader.rs index a7172ce9b74c9..8d7942cd9afe9 100644 --- a/crates/polars-lazy/src/scan/file_list_reader.rs +++ b/crates/polars-lazy/src/scan/file_list_reader.rs @@ -40,6 +40,10 @@ pub trait LazyFileListReader: Clone { .map(|r| { let path = r?; self.clone() + // Each individual reader should not apply a row limit. + .with_n_rows(None) + // Each individual reader should not apply a row index. + .with_row_index(None) .with_path(path.clone()) .with_rechunk(false) .finish_no_glob() @@ -100,6 +104,12 @@ pub trait LazyFileListReader: Clone { #[must_use] fn with_paths(self, paths: Arc<[PathBuf]>) -> Self; + /// Configure the row limit. + fn with_n_rows(self, n_rows: impl Into>) -> Self; + + /// Configure the row index. + fn with_row_index(self, row_index: impl Into>) -> Self; + /// Rechunk the memory to contiguous chunks when parsing is done. fn rechunk(&self) -> bool; diff --git a/crates/polars-lazy/src/scan/ipc.rs b/crates/polars-lazy/src/scan/ipc.rs index 653b7e368f91b..a849bec3d1c88 100644 --- a/crates/polars-lazy/src/scan/ipc.rs +++ b/crates/polars-lazy/src/scan/ipc.rs @@ -1,6 +1,7 @@ use std::path::{Path, PathBuf}; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::RowIndex; use crate::prelude::*; @@ -12,6 +13,8 @@ pub struct ScanArgsIpc { pub rechunk: bool, pub row_index: Option, pub memmap: bool, + #[cfg(feature = "cloud")] + pub cloud_options: Option, } impl Default for ScanArgsIpc { @@ -22,6 +25,8 @@ impl Default for ScanArgsIpc { rechunk: false, row_index: None, memmap: true, + #[cfg(feature = "cloud")] + cloud_options: Default::default(), } } } @@ -58,6 +63,8 @@ impl LazyFileListReader for LazyIpcReader { args.cache, args.row_index.clone(), args.rechunk, + #[cfg(feature = "cloud")] + args.cloud_options, )? .build() .into(); @@ -89,6 +96,16 @@ impl LazyFileListReader for LazyIpcReader { self } + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.args.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.args.row_index = row_index.into(); + self + } + fn rechunk(&self) -> bool { self.args.rechunk } diff --git a/crates/polars-lazy/src/scan/ndjson.rs b/crates/polars-lazy/src/scan/ndjson.rs index ab9094295c23c..e2b2691e4e08f 100644 --- a/crates/polars-lazy/src/scan/ndjson.rs +++ b/crates/polars-lazy/src/scan/ndjson.rs @@ -77,7 +77,7 @@ impl LazyJsonLineReader { self } - /// Reduce memory usage in expensive of performance + /// Reduce memory usage at the expense of performance #[must_use] pub fn low_memory(mut self, toggle: bool) -> Self { self.low_memory = toggle; @@ -123,6 +123,16 @@ impl LazyFileListReader for LazyJsonLineReader { self } + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.row_index = row_index.into(); + self + } + fn rechunk(&self) -> bool { self.rechunk } diff --git a/crates/polars-lazy/src/scan/parquet.rs b/crates/polars-lazy/src/scan/parquet.rs index 927a1c2f77ead..aa0dc47b9ca52 100644 --- a/crates/polars-lazy/src/scan/parquet.rs +++ b/crates/polars-lazy/src/scan/parquet.rs @@ -115,6 +115,16 @@ impl LazyFileListReader for LazyParquetReader { self } + fn with_n_rows(mut self, n_rows: impl Into>) -> Self { + self.args.n_rows = n_rows.into(); + self + } + + fn with_row_index(mut self, row_index: impl Into>) -> Self { + self.args.row_index = row_index.into(); + self + } + fn rechunk(&self) -> bool { self.args.rechunk } diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 70aa4d41d7c89..24110424f20c5 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -5,6 +5,7 @@ use polars_ops::prelude::ClosedInterval; use super::*; #[test] +#[cfg(feature = "parquet")] fn test_parquet_exec() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); // filter @@ -36,6 +37,7 @@ fn test_parquet_exec() -> PolarsResult<()> { } #[test] +#[cfg(all(feature = "parquet", feature = "is_between"))] fn test_parquet_statistics_no_skip() { let _guard = SINGLE_LOCK.lock().unwrap(); init_files(); @@ -109,6 +111,7 @@ fn test_parquet_statistics_no_skip() { } #[test] +#[cfg(all(feature = "parquet", feature = "is_between"))] fn test_parquet_statistics() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); init_files(); @@ -413,6 +416,8 @@ fn test_ipc_globbing() -> PolarsResult<()> { rechunk: false, row_index: None, memmap: true, + #[cfg(feature = "cloud")] + cloud_options: None, }, )? .collect()?; diff --git a/crates/polars-lazy/src/tests/mod.rs b/crates/polars-lazy/src/tests/mod.rs index 058a40a9b38ef..be7deb7994682 100644 --- a/crates/polars-lazy/src/tests/mod.rs +++ b/crates/polars-lazy/src/tests/mod.rs @@ -30,7 +30,6 @@ fn load_df() -> DataFrame { } use std::io::Cursor; -use std::iter::FromIterator; use optimization_checks::*; use polars_core::chunked_array::builder::get_list_builder; @@ -42,7 +41,7 @@ use polars_core::prelude::*; pub(crate) use polars_core::SINGLE_LOCK; use polars_io::prelude::*; use polars_plan::logical_plan::{ - ArenaLpIter, OptimizationRule, SimplifyExprRule, StackOptimizer, TypeCoercionRule, + OptimizationRule, SimplifyExprRule, StackOptimizer, TypeCoercionRule, }; #[cfg(feature = "cov")] diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 4d997343e68bf..c352c43e88ef0 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1,10 +1,7 @@ -use polars_core::frame::explode::MeltArgs; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use super::*; -#[cfg(feature = "range")] -use crate::dsl::arg_sort_by; #[test] fn test_lazy_with_column() { @@ -440,9 +437,9 @@ fn test_lazy_query_10() { let z: Series = DurationChunked::from_duration( "z", [ - ChronoDuration::hours(1), - ChronoDuration::hours(2), - ChronoDuration::hours(3), + ChronoDuration::try_hours(1).unwrap(), + ChronoDuration::try_hours(2).unwrap(), + ChronoDuration::try_hours(3).unwrap(), ], TimeUnit::Nanoseconds, ) diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index 756027bf0e3c7..a25a015a1e42e 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -213,6 +213,7 @@ fn test_streaming_inner_join3() -> PolarsResult<()> { assert_streaming_with_default(q, true, false); Ok(()) } + #[test] fn test_streaming_inner_join2() -> PolarsResult<()> { let lf_left = df![ @@ -388,3 +389,29 @@ fn test_sort_maintain_order_streaming() -> PolarsResult<()> { ]?)); Ok(()) } + +#[test] +fn test_streaming_outer_join() -> PolarsResult<()> { + let lf_left = df![ + "a"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], + "b"=> [0, 0, 0, 3, 0, 1, 3, 3, 3, 1, 4, 4, 2, 1, 1, 3, 1, 4, 2, 2], + ]? + .lazy(); + + let lf_right = df![ + "a"=> [10, 18, 13, 9, 1, 13, 14, 12, 15, 11], + "b"=> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ]? + .lazy(); + + let q = lf_left + .outer_join(lf_right, col("a"), col("a")) + .sort_by_exprs([all()], [false], false, false); + + // Toggle so that the join order is swapped. + for toggle in [true, true] { + assert_streaming_with_default(q.clone().with_streaming(toggle), true, false); + } + + Ok(()) +} diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index fc05249d42523..2caf329d430a6 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -17,7 +17,7 @@ polars-utils = { workspace = true } ahash = { workspace = true } aho-corasick = { workspace = true, optional = true } -argminmax = { version = "0.6.1", default-features = false, features = ["float"] } +argminmax = { version = "0.6.2", default-features = false, features = ["float"] } arrow = { workspace = true } base64 = { workspace = true, optional = true } bytemuck = { workspace = true } @@ -27,7 +27,6 @@ either = { workspace = true } hashbrown = { workspace = true } hex = { workspace = true, optional = true } indexmap = { workspace = true } -jsonpath_lib = { version = "0.3", optional = true, git = "https://github.com/ritchie46/jsonpath", branch = "improve_compiled" } memchr = { workspace = true } num-traits = { workspace = true } rand = { workspace = true, optional = true, features = ["small_rng", "std"] } @@ -39,6 +38,11 @@ serde_json = { workspace = true, optional = true } smartstring = { workspace = true } unicode-reverse = { workspace = true, optional = true } +[dependencies.jsonpath_lib] +package = "jsonpath_lib_polars_vendor" +optional = true +version = "0.0.1" + [dev-dependencies] rand = { workspace = true, features = ["small_rng"] } @@ -109,7 +113,7 @@ top_k = [] pivot = ["polars-core/reinterpret"] cross_join = [] chunked_ids = [] -asof_join = ["polars-core/asof_join"] +asof_join = [] semi_anti_join = [] array_any_all = ["dtype-array"] array_count = ["dtype-array"] diff --git a/crates/polars-ops/src/chunked_array/array/dispersion.rs b/crates/polars-ops/src/chunked_array/array/dispersion.rs index 7cacfcf9aad34..e7039ac5db2e4 100644 --- a/crates/polars-ops/src/chunked_array/array/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/array/dispersion.rs @@ -1,5 +1,3 @@ -use polars_core::datatypes::ArrayChunked; - use super::*; pub(super) fn median_with_nulls(ca: &ArrayChunked) -> PolarsResult { diff --git a/crates/polars-ops/src/chunked_array/array/get.rs b/crates/polars-ops/src/chunked_array/array/get.rs index 6cb5630676e92..f8fc2e894acf8 100644 --- a/crates/polars-ops/src/chunked_array/array/get.rs +++ b/crates/polars-ops/src/chunked_array/array/get.rs @@ -1,7 +1,6 @@ use arrow::legacy::kernels::fixed_size_list::{ sub_fixed_size_list_get, sub_fixed_size_list_get_literal, }; -use polars_core::datatypes::ArrayChunked; use polars_core::prelude::arity::binary_to_series; use super::*; diff --git a/crates/polars-ops/src/chunked_array/array/join.rs b/crates/polars-ops/src/chunked_array/array/join.rs index 3aa5f223b0e7d..69b4d5d3815b0 100644 --- a/crates/polars-ops/src/chunked_array/array/join.rs +++ b/crates/polars-ops/src/chunked_array/array/join.rs @@ -1,7 +1,5 @@ use std::fmt::Write; -use polars_core::prelude::ArrayChunked; - use super::*; fn join_literal( diff --git a/crates/polars-ops/src/chunked_array/array/min_max.rs b/crates/polars-ops/src/chunked_array/array/min_max.rs index c4857fc94ff90..c61d422e42775 100644 --- a/crates/polars-ops/src/chunked_array/array/min_max.rs +++ b/crates/polars-ops/src/chunked_array/array/min_max.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use polars_compute::min_max::MinMaxKernel; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; diff --git a/crates/polars-ops/src/chunked_array/array/sum_mean.rs b/crates/polars-ops/src/chunked_array/array/sum_mean.rs index f998e0729bb95..d27f1117fd3a1 100644 --- a/crates/polars-ops/src/chunked_array/array/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/array/sum_mean.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::Bitmap; use arrow::legacy::utils::CustomIterTools; use arrow::types::NativeType; diff --git a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs index 859fddd7e461f..bb3b5b77c26f7 100644 --- a/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs +++ b/crates/polars-ops/src/chunked_array/datetime/replace_time_zone.rs @@ -19,8 +19,7 @@ pub fn replace_time_zone( let from_tz = parse_time_zone(from_time_zone)?; let to_tz = parse_time_zone(time_zone.unwrap_or("UTC"))?; if (from_tz == to_tz) - & ((from_tz == UTC) - | ((ambiguous.len() == 1) & (unsafe { ambiguous.get_unchecked(0) } == Some("raise")))) + & ((from_tz == UTC) | ((ambiguous.len() == 1) & (ambiguous.get(0) == Some("raise")))) { let mut out = datetime .0 @@ -39,42 +38,103 @@ pub fn replace_time_zone( TimeUnit::Microseconds => datetime_to_timestamp_us, TimeUnit::Nanoseconds => datetime_to_timestamp_ns, }; - let out = match ambiguous.len() { - 1 => match unsafe { ambiguous.get_unchecked(0) } { - Some(ambiguous) => datetime.0.try_apply(|timestamp| { - let ndt = timestamp_to_datetime(timestamp); - Ok(datetime_to_timestamp(convert_to_naive_local( - &from_tz, - &to_tz, - ndt, - Ambiguous::from_str(ambiguous)?, - )?)) - }), - _ => Ok(datetime.0.apply(|_| None)), + + let out = if ambiguous.len() == 1 && ambiguous.get(0) != Some("null") { + impl_replace_time_zone_fast( + datetime, + ambiguous.get(0), + timestamp_to_datetime, + datetime_to_timestamp, + &from_tz, + &to_tz, + ) + } else { + impl_replace_time_zone( + datetime, + ambiguous, + timestamp_to_datetime, + datetime_to_timestamp, + &from_tz, + &to_tz, + ) + }; + + let mut out = out?.into_datetime(datetime.time_unit(), time_zone.map(|x| x.to_string())); + if from_time_zone == "UTC" && ambiguous.len() == 1 && ambiguous.get(0) == Some("raise") { + // In general, the sortedness flag can't be preserved. + // To be safe, we only do so in the simplest case when we know for sure that there is no "daylight savings weirdness" going on, i.e.: + // - `from_tz` is guaranteed to not observe daylight savings time; + // - user is just passing 'raise' to 'ambiguous'. + // Both conditions above need to be satisfied. + out.set_sorted_flag(datetime.is_sorted_flag()); + } + Ok(out) +} + +/// If `ambiguous` is length-1 and not equal to "null", we can take a slightly faster path. +pub fn impl_replace_time_zone_fast( + datetime: &Logical, + ambiguous: Option<&str>, + timestamp_to_datetime: fn(i64) -> NaiveDateTime, + datetime_to_timestamp: fn(NaiveDateTime) -> i64, + from_tz: &chrono_tz::Tz, + to_tz: &chrono_tz::Tz, +) -> PolarsResult { + match ambiguous { + Some(ambiguous) => datetime.0.try_apply(|timestamp| { + let ndt = timestamp_to_datetime(timestamp); + Ok(datetime_to_timestamp( + convert_to_naive_local(from_tz, to_tz, ndt, Ambiguous::from_str(ambiguous)?)? + .expect("we didn't use Ambiguous::Null"), + )) + }), + _ => Ok(datetime.0.apply(|_| None)), + } +} + +pub fn impl_replace_time_zone( + datetime: &Logical, + ambiguous: &StringChunked, + timestamp_to_datetime: fn(i64) -> NaiveDateTime, + datetime_to_timestamp: fn(NaiveDateTime) -> i64, + from_tz: &chrono_tz::Tz, + to_tz: &chrono_tz::Tz, +) -> PolarsResult { + match ambiguous.len() { + 1 => { + debug_assert!(ambiguous.get(0) == Some("null")); + let iter = datetime.0.downcast_iter().map(|arr| { + let element_iter = arr.iter().map(|timestamp_opt| match timestamp_opt { + Some(timestamp) => { + let ndt = timestamp_to_datetime(*timestamp); + let res = convert_to_naive_local( + from_tz, + to_tz, + ndt, + Ambiguous::from_str("null")?, + )?; + Ok::<_, PolarsError>(res.map(datetime_to_timestamp)) + }, + None => Ok(None), + }); + element_iter.try_collect_arr() + }); + ChunkedArray::try_from_chunk_iter(datetime.0.name(), iter) }, _ => try_binary_elementwise(datetime, ambiguous, |timestamp_opt, ambiguous_opt| { match (timestamp_opt, ambiguous_opt) { (Some(timestamp), Some(ambiguous)) => { let ndt = timestamp_to_datetime(timestamp); - Ok(Some(datetime_to_timestamp(convert_to_naive_local( - &from_tz, - &to_tz, + Ok(convert_to_naive_local( + from_tz, + to_tz, ndt, Ambiguous::from_str(ambiguous)?, - )?))) + )? + .map(datetime_to_timestamp)) }, _ => Ok(None), } }), - }; - let mut out = out?.into_datetime(datetime.time_unit(), time_zone.map(|x| x.to_string())); - if from_time_zone == "UTC" && ambiguous.len() == 1 && ambiguous.get(0).unwrap() == "raise" { - // In general, the sortedness flag can't be preserved. - // To be safe, we only do so in the simplest case when we know for sure that there is no "daylight savings weirdness" going on, i.e.: - // - `from_tz` is guaranteed to not observe daylight savings time; - // - user is just passing 'raise' to 'ambiguous'. - // Both conditions above need to be satisfied. - out.set_sorted_flag(datetime.is_sorted_flag()); } - Ok(out) } diff --git a/crates/polars-ops/src/chunked_array/gather/chunked.rs b/crates/polars-ops/src/chunked_array/gather/chunked.rs index 6e7365d61a969..d5741961de160 100644 --- a/crates/polars-ops/src/chunked_array/gather/chunked.rs +++ b/crates/polars-ops/src/chunked_array/gather/chunked.rs @@ -1,13 +1,21 @@ +use std::borrow::Cow; +use std::fmt::Debug; + +use arrow::array::{Array, BinaryViewArray, View, INLINE_VIEW_SIZE}; +use arrow::bitmap::MutableBitmap; +use arrow::buffer::Buffer; +use arrow::legacy::trusted_len::TrustedLenPush; +use polars_core::prelude::gather::_update_gather_sorted_flag; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::with_match_physical_numeric_polars_type; -use polars_utils::index::ChunkId; use polars_utils::slice::GetSaferUnchecked; use crate::frame::IntoDf; pub trait DfTake: IntoDf { /// Take elements by a slice of [`ChunkId`]s. + /// /// # Safety /// Does not do any bound checks. /// `sorted` indicates if the chunks are sorted. @@ -16,17 +24,18 @@ pub trait DfTake: IntoDf { .to_df() ._apply_columns(&|s| s.take_chunked_unchecked(idx, sorted)); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// Take elements by a slice of optional [`ChunkId`]s. + /// /// # Safety /// Does not do any bound checks. - unsafe fn _take_opt_chunked_unchecked_seq(&self, idx: &[Option]) -> DataFrame { + unsafe fn _take_opt_chunked_unchecked_seq(&self, idx: &[NullableChunkId]) -> DataFrame { let cols = self .to_df() ._apply_columns(&|s| s.take_opt_chunked_unchecked(idx)); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// # Safety @@ -36,17 +45,19 @@ pub trait DfTake: IntoDf { .to_df() ._apply_columns_par(&|s| s.take_chunked_unchecked(idx, sorted)); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } /// # Safety /// Doesn't perform any bound checks - unsafe fn _take_opt_chunked_unchecked(&self, idx: &[Option]) -> DataFrame { + /// + /// Check for null state in `ChunkId`. + unsafe fn _take_opt_chunked_unchecked(&self, idx: &[ChunkId]) -> DataFrame { let cols = self .to_df() ._apply_columns_par(&|s| s.take_opt_chunked_unchecked(idx)); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } @@ -60,12 +71,27 @@ pub trait TakeChunked { /// # Safety /// This function doesn't do any bound checks. - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self; + unsafe fn take_opt_chunked_unchecked(&self, by: &[ChunkId]) -> Self; +} + +fn prepare_series(s: &Series) -> Cow { + let phys = if s.dtype().is_nested() { + Cow::Borrowed(s) + } else { + s.to_physical_repr() + }; + // If this is hit the cast rechunked the data and the gather will OOB + assert_eq!( + phys.chunks().len(), + s.chunks().len(), + "implementation error" + ); + phys } impl TakeChunked for Series { unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { - let phys = self.to_physical_repr(); + let phys = prepare_series(self); use DataType::*; let out = match phys.dtype() { dt if dt.is_numeric() => { @@ -80,11 +106,14 @@ impl TakeChunked for Series { }, Binary => { let ca = phys.binary().unwrap(); - ca.take_chunked_unchecked(by, sorted).into_series() + let out = take_unchecked_binview(ca, by, sorted); + out.into_series() }, String => { let ca = phys.str().unwrap(); - ca.take_chunked_unchecked(by, sorted).into_series() + let ca = ca.as_binary(); + let out = take_unchecked_binview(&ca, by, sorted); + out.to_string().into_series() }, List(_) => { let ca = phys.list().unwrap(); @@ -116,8 +145,9 @@ impl TakeChunked for Series { unsafe { out.cast_unchecked(self.dtype()).unwrap() } } - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { - let phys = self.to_physical_repr(); + /// Take function that checks of null state in `ChunkIdx`. + unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self { + let phys = prepare_series(self); use DataType::*; let out = match phys.dtype() { dt if dt.is_numeric() => { @@ -132,11 +162,14 @@ impl TakeChunked for Series { }, Binary => { let ca = phys.binary().unwrap(); - ca.take_opt_chunked_unchecked(by).into_series() + let out = take_unchecked_binview_opt(ca, by); + out.into_series() }, String => { let ca = phys.str().unwrap(); - ca.take_opt_chunked_unchecked(by).into_series() + let ca = ca.as_binary(); + let out = take_unchecked_binview_opt(&ca, by); + out.to_string().into_series() }, List(_) => { let ca = phys.list().unwrap(); @@ -172,6 +205,7 @@ impl TakeChunked for Series { impl TakeChunked for ChunkedArray where T: PolarsDataType, + T::Array: Debug, { unsafe fn take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Self { let arrow_dtype = self.dtype().to_arrow(true); @@ -179,6 +213,10 @@ where let mut out = if let Some(iter) = self.downcast_slices() { let targets = iter.collect::>(); let iter = by.iter().map(|chunk_id| { + debug_assert!( + !chunk_id.is_null(), + "null chunks should not hit this branch" + ); let (chunk_idx, array_idx) = chunk_id.extract(); let vals = targets.get_unchecked_release(chunk_idx as usize); vals.get_unchecked_release(array_idx as usize).clone() @@ -189,6 +227,10 @@ where } else { let targets = self.downcast_iter().collect::>(); let iter = by.iter().map(|chunk_id| { + debug_assert!( + !chunk_id.is_null(), + "null chunks should not hit this branch" + ); let (chunk_idx, array_idx) = chunk_id.extract(); let vals = targets.get_unchecked_release(chunk_idx as usize); vals.get_unchecked(array_idx as usize) @@ -196,11 +238,13 @@ where let arr = iter.collect_arr_trusted_with_dtype(arrow_dtype); ChunkedArray::with_chunk(self.name(), arr) }; - out.set_sorted_flag(sorted); + let sorted_flag = _update_gather_sorted_flag(self.is_sorted_flag(), sorted); + out.set_sorted_flag(sorted_flag); out } - unsafe fn take_opt_chunked_unchecked(&self, by: &[Option]) -> Self { + // Take function that checks of null state in `ChunkIdx`. + unsafe fn take_opt_chunked_unchecked(&self, by: &[NullableChunkId]) -> Self { let arrow_dtype = self.dtype().to_arrow(true); if let Some(iter) = self.downcast_slices() { @@ -208,11 +252,13 @@ where let arr = by .iter() .map(|chunk_id| { - chunk_id.map(|chunk_id| { + if chunk_id.is_null() { + None + } else { let (chunk_idx, array_idx) = chunk_id.extract(); let vals = *targets.get_unchecked_release(chunk_idx as usize); - vals.get_unchecked_release(array_idx as usize).clone() - }) + Some(vals.get_unchecked_release(array_idx as usize).clone()) + } }) .collect_arr_trusted_with_dtype(arrow_dtype); @@ -222,11 +268,13 @@ where let arr = by .iter() .map(|chunk_id| { - chunk_id.and_then(|chunk_id| { + if chunk_id.is_null() { + None + } else { let (chunk_idx, array_idx) = chunk_id.extract(); let vals = *targets.get_unchecked_release(chunk_idx as usize); vals.get_unchecked(array_idx as usize) - }) + } }) .collect_arr_trusted_with_dtype(arrow_dtype); @@ -252,20 +300,260 @@ unsafe fn take_unchecked_object(s: &Series, by: &[ChunkId], _sorted: IsSorted) - } #[cfg(feature = "object")] -unsafe fn take_opt_unchecked_object(s: &Series, by: &[Option]) -> Series { +unsafe fn take_opt_unchecked_object(s: &Series, by: &[NullableChunkId]) -> Series { let DataType::Object(_, reg) = s.dtype() else { unreachable!() }; let reg = reg.as_ref().unwrap(); let mut builder = (*reg.builder_constructor)(s.name(), by.len()); - by.iter().for_each(|chunk_id| match chunk_id { - None => builder.append_null(), - Some(chunk_id) => { + by.iter().for_each(|chunk_id| { + if chunk_id.is_null() { + builder.append_null() + } else { let (chunk_idx, array_idx) = chunk_id.extract(); let object = s.get_object_chunked_unchecked(chunk_idx as usize, array_idx as usize); builder.append_option(object.map(|v| v.as_any())) - }, + } }); builder.to_series() } + +#[allow(clippy::unnecessary_cast)] +#[inline(always)] +fn rewrite_view(mut view: View, chunk_idx: IdxSize) -> View { + let chunk_idx = chunk_idx as u32; + let offset = [0, chunk_idx][(view.length > INLINE_VIEW_SIZE) as usize]; + view.buffer_idx += offset; + view +} + +#[allow(clippy::unnecessary_cast)] +unsafe fn take_unchecked_binview( + ca: &BinaryChunked, + by: &[ChunkId], + sorted: IsSorted, +) -> BinaryChunked { + let views = ca + .downcast_iter() + .map(|arr| arr.views().as_slice()) + .collect::>(); + let buffers: Arc<[Buffer]> = ca + .downcast_iter() + .flat_map(|arr| arr.data_buffers().as_ref()) + .cloned() + .collect(); + + let (views, validity) = if ca.null_count() == 0 { + let views = by + .iter() + .map(|chunk_id| { + let (chunk_idx, array_idx) = chunk_id.extract(); + let array_idx = array_idx as usize; + + let target = *views.get_unchecked_release(chunk_idx as usize); + let view = *target.get_unchecked_release(array_idx); + + rewrite_view(view, chunk_idx) + }) + .collect::>(); + + (views, None) + } else { + let targets = ca.downcast_iter().collect::>(); + + let mut mut_views = Vec::with_capacity(by.len()); + let mut validity = MutableBitmap::with_capacity(by.len()); + + for id in by.iter() { + let (chunk_idx, array_idx) = id.extract(); + let array_idx = array_idx as usize; + + let target = *targets.get_unchecked_release(chunk_idx as usize); + if target.is_null_unchecked(array_idx) { + mut_views.push_unchecked(View::default()); + validity.push_unchecked(false) + } else { + let target = *views.get_unchecked_release(chunk_idx as usize); + let view = *target.get_unchecked_release(array_idx); + let view = rewrite_view(view, chunk_idx); + mut_views.push_unchecked(view); + validity.push_unchecked(true) + } + } + + (mut_views, Some(validity.freeze())) + }; + + let arr = BinaryViewArray::new_unchecked_unknown_md( + ArrowDataType::BinaryView, + views.into(), + buffers, + validity, + None, + ) + .maybe_gc(); + + let mut out = BinaryChunked::with_chunk(ca.name(), arr); + let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), sorted); + out.set_sorted_flag(sorted_flag); + out +} + +unsafe fn take_unchecked_binview_opt(ca: &BinaryChunked, by: &[NullableChunkId]) -> BinaryChunked { + let views = ca + .downcast_iter() + .map(|arr| arr.views().as_slice()) + .collect::>(); + let buffers: Arc<[Buffer]> = ca + .downcast_iter() + .flat_map(|arr| arr.data_buffers().as_ref()) + .cloned() + .collect(); + + let targets = ca.downcast_iter().collect::>(); + + let mut mut_views = Vec::with_capacity(by.len()); + let mut validity = MutableBitmap::with_capacity(by.len()); + + let (views, validity) = if ca.null_count() == 0 { + for id in by.iter() { + if id.is_null() { + mut_views.push_unchecked(View::default()); + validity.push_unchecked(false) + } else { + let (chunk_idx, array_idx) = id.extract(); + let array_idx = array_idx as usize; + + let target = *views.get_unchecked_release(chunk_idx as usize); + let view = *target.get_unchecked_release(array_idx); + let view = rewrite_view(view, chunk_idx); + + mut_views.push_unchecked(view); + validity.push_unchecked(true) + } + } + (mut_views, Some(validity.freeze())) + } else { + for id in by.iter() { + if id.is_null() { + mut_views.push_unchecked(View::default()); + validity.push_unchecked(false) + } else { + let (chunk_idx, array_idx) = id.extract(); + let array_idx = array_idx as usize; + + let target = *targets.get_unchecked_release(chunk_idx as usize); + if target.is_null_unchecked(array_idx) { + mut_views.push_unchecked(View::default()); + validity.push_unchecked(false) + } else { + let target = *views.get_unchecked_release(chunk_idx as usize); + let view = *target.get_unchecked_release(array_idx); + let view = rewrite_view(view, chunk_idx); + mut_views.push_unchecked(view); + validity.push_unchecked(true); + } + } + } + + (mut_views, Some(validity.freeze())) + }; + + let arr = BinaryViewArray::new_unchecked_unknown_md( + ArrowDataType::BinaryView, + views.into(), + buffers, + validity, + None, + ) + .maybe_gc(); + + BinaryChunked::with_chunk(ca.name(), arr) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_binview_chunked_gather() { + unsafe { + // # Series without nulls; + let mut s_1 = Series::new( + "a", + &["1 loooooooooooong string 1", "2 loooooooooooong string 2"], + ); + let s_2 = Series::new( + "a", + &[ + "11 loooooooooooong string 11", + "22 loooooooooooong string 22", + ], + ); + s_1.append(&s_2).unwrap(); + + assert_eq!(s_1.n_chunks(), 2); + + // ## Ids without nulls; + let by = [ + ChunkId::store(0, 0), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + + let out = s_1.take_chunked_unchecked(&by, IsSorted::Not); + let idx = IdxCa::new("", [0, 1, 3, 2]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals(&expected)); + + // ## Ids with nulls; + let by: [ChunkId; 4] = [ + ChunkId::null(), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + let out = s_1.take_opt_chunked_unchecked(&by); + + let idx = IdxCa::new("", [None, Some(1), Some(3), Some(2)]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals_missing(&expected)); + + // # Series with nulls; + let mut s_1 = Series::new( + "a", + &["1 loooooooooooong string 1", "2 loooooooooooong string 2"], + ); + let s_2 = Series::new("a", &[Some("11 loooooooooooong string 11"), None]); + s_1.append(&s_2).unwrap(); + + // ## Ids without nulls; + let by = [ + ChunkId::store(0, 0), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + + let out = s_1.take_chunked_unchecked(&by, IsSorted::Not); + let idx = IdxCa::new("", [0, 1, 3, 2]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals_missing(&expected)); + + // ## Ids with nulls; + let by: [ChunkId; 4] = [ + ChunkId::null(), + ChunkId::store(0, 1), + ChunkId::store(1, 1), + ChunkId::store(1, 0), + ]; + let out = s_1.take_opt_chunked_unchecked(&by); + + let idx = IdxCa::new("", [None, Some(1), Some(3), Some(2)]); + let expected = s_1.rechunk().take(&idx).unwrap(); + assert!(out.equals_missing(&expected)); + } + } +} diff --git a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs index e656fcded1cb0..ba1427e6f6fd6 100644 --- a/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs +++ b/crates/polars-ops/src/chunked_array/gather_skip_nulls.rs @@ -155,7 +155,6 @@ mod test { use rand::distributions::uniform::SampleUniform; use rand::prelude::*; - use rand::rngs::SmallRng; use super::*; diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index 9c8653f2c1ff0..5833a1bc784df 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -1,16 +1,10 @@ use std::fmt::Write; -use arrow::legacy::index::IdxSize; use num_traits::ToPrimitive; -use polars_core::datatypes::PolarsNumericType; -use polars_core::prelude::{ - ChunkCast, ChunkSort, ChunkedArray, DataType, StringChunkedBuilder, StructChunked, UInt32Type, - *, -}; +use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; -use polars_error::PolarsResult; use polars_utils::float::IsFloat; -use polars_utils::total_ord::TotalOrdWrap; +use polars_utils::total_ord::ToTotalOrd; fn compute_hist( ca: &ChunkedArray, @@ -26,7 +20,7 @@ where let (breaks, count) = if let Some(bins) = bins { let mut breaks = Vec::with_capacity(bins.len() + 1); breaks.extend_from_slice(bins); - breaks.sort_unstable_by_key(|k| TotalOrdWrap(*k)); + breaks.sort_unstable_by_key(|k| k.to_total_ord()); breaks.push(f64::INFINITY); let sorted = ca.sort(false); diff --git a/crates/polars-ops/src/chunked_array/list/dispersion.rs b/crates/polars-ops/src/chunked_array/list/dispersion.rs index 2738ae869425d..3d47520c1d924 100644 --- a/crates/polars-ops/src/chunked_array/list/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/list/dispersion.rs @@ -1,5 +1,3 @@ -use polars_core::datatypes::ListChunked; - use super::*; pub(super) fn median_with_nulls(ca: &ListChunked) -> Series { diff --git a/crates/polars-ops/src/chunked_array/list/hash.rs b/crates/polars-ops/src/chunked_array/list/hash.rs index 5931753f6ebf2..fe00dcdceeb65 100644 --- a/crates/polars-ops/src/chunked_array/list/hash.rs +++ b/crates/polars-ops/src/chunked_array/list/hash.rs @@ -1,17 +1,18 @@ use std::hash::Hash; use polars_core::export::_boost_hash_combine; -use polars_core::export::ahash::{self}; use polars_core::export::rayon::prelude::*; use polars_core::utils::NoNull; -use polars_core::POOL; +use polars_core::{with_match_physical_float_polars_type, POOL}; +use polars_utils::total_ord::{ToTotalOrd, TotalHash}; use super::*; fn hash_agg(ca: &ChunkedArray, random_state: &ahash::RandomState) -> u64 where - T: PolarsIntegerType, - T::Native: Hash, + T: PolarsNumericType, + T::Native: TotalHash + ToTotalOrd, + ::TotalOrdItem: Hash, { // Note that we don't use the no null branch! This can break in unexpected ways. // for instance with threading we split an array in n_threads, this may lead to @@ -30,7 +31,7 @@ where for opt_v in arr.iter() { match opt_v { Some(v) => { - let r = random_state.hash_one(v); + let r = random_state.hash_one(v.to_total_ord()); hash_agg = _boost_hash_combine(hash_agg, r); }, None => { @@ -60,7 +61,12 @@ pub(crate) fn hash(ca: &mut ListChunked, build_hasher: ahash::RandomState) -> UI .map(|opt_s: Option| match opt_s { None => null_hash, Some(s) => { - if s.bit_repr_is_large() { + if s.dtype().is_float() { + with_match_physical_float_polars_type!(s.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); + hash_agg(ca, &build_hasher) + }) + } else if s.bit_repr_is_large() { let ca = s.bit_repr_large(); hash_agg(&ca, &build_hasher) } else { diff --git a/crates/polars-ops/src/chunked_array/list/min_max.rs b/crates/polars-ops/src/chunked_array/list/min_max.rs index 51db2f079b084..dd043110be2ee 100644 --- a/crates/polars-ops/src/chunked_array/list/min_max.rs +++ b/crates/polars-ops/src/chunked_array/list/min_max.rs @@ -1,6 +1,5 @@ -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::Bitmap; -use arrow::legacy::array::PolarsArray; use arrow::types::NativeType; use polars_compute::min_max::MinMaxKernel; use polars_core::prelude::*; diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 201189fc86add..38ca7732c40c0 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -1,4 +1,3 @@ -use std::convert::TryFrom; use std::fmt::Write; use arrow::array::ValueSize; diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index 535b16d85c2ef..2eec5efbe6fc4 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -11,7 +11,7 @@ use arrow::offset::OffsetsBuffer; use arrow::types::NativeType; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_type; -use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -30,12 +30,12 @@ where } } -impl MaterializeValues>> for MutablePrimitiveArray +impl MaterializeValues>> for MutablePrimitiveArray where T: NativeType, { - fn extend_buf>>>(&mut self, values: I) -> usize { - self.extend(values); + fn extend_buf>>>(&mut self, values: I) -> usize { + self.extend(values.map(|x| x.0)); self.len() } } @@ -102,8 +102,10 @@ where } } -fn copied_wrapper_opt(v: Option<&T>) -> Option> { - v.copied().map(TotalOrdWrap) +fn copied_wrapper_opt( + v: Option<&T>, +) -> as ToTotalOrd>::TotalOrdItem { + v.copied().to_total_ord() } #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] @@ -136,13 +138,14 @@ fn primitive( validity: Option, ) -> PolarsResult> where - T: NativeType + TotalHash + Copy + TotalEq, + T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let broadcast_lhs = offsets_a.len() == 2; let broadcast_rhs = offsets_b.len() == 2; let mut set = Default::default(); - let mut set2: PlIndexSet>> = Default::default(); + let mut set2: PlIndexSet< as ToTotalOrd>::TotalOrdItem> = Default::default(); let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max( *offsets_a.last().unwrap(), @@ -418,6 +421,7 @@ pub fn list_set_operation( b.prune_empty_chunks(); // Make categoricals compatible + #[cfg(feature = "dtype-categorical")] if let (DataType::Categorical(_, _), DataType::Categorical(_, _)) = (&a.inner_dtype(), &b.inner_dtype()) { diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index fe48e397459ba..e3a14e2340f78 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -1,10 +1,9 @@ use std::ops::Div; -use arrow::array::{Array, ArrayRef, PrimitiveArray}; +use arrow::array::{Array, PrimitiveArray}; use arrow::bitmap::Bitmap; use arrow::legacy::utils::CustomIterTools; use arrow::types::NativeType; -use polars_core::datatypes::ListChunked; use polars_core::export::num::{NumCast, ToPrimitive}; use polars_utils::unwrap::UnwrapUncheckedRelease; diff --git a/crates/polars-ops/src/chunked_array/mode.rs b/crates/polars-ops/src/chunked_array/mode.rs index 62596f38da2f2..26b728306c5ee 100644 --- a/crates/polars-ops/src/chunked_array/mode.rs +++ b/crates/polars-ops/src/chunked_array/mode.rs @@ -1,5 +1,4 @@ use arrow::legacy::utils::CustomIterTools; -use polars_core::frame::group_by::IntoGroupsProxy; use polars_core::prelude::*; use polars_core::{with_match_physical_integer_polars_type, POOL}; diff --git a/crates/polars-ops/src/chunked_array/repeat_by.rs b/crates/polars-ops/src/chunked_array/repeat_by.rs index bd844501f94de..bdba858d5719d 100644 --- a/crates/polars-ops/src/chunked_array/repeat_by.rs +++ b/crates/polars-ops/src/chunked_array/repeat_by.rs @@ -1,5 +1,4 @@ use arrow::array::ListArray; -use arrow::legacy::array::ListFromIter; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; diff --git a/crates/polars-ops/src/chunked_array/strings/concat.rs b/crates/polars-ops/src/chunked_array/strings/concat.rs index 5d33cf9c91bab..c290b076cac4e 100644 --- a/crates/polars-ops/src/chunked_array/strings/concat.rs +++ b/crates/polars-ops/src/chunked_array/strings/concat.rs @@ -1,6 +1,5 @@ use arrow::array::{Utf8Array, ValueSize}; use arrow::compute::cast::utf8_to_utf8view; -use arrow::legacy::array::default_arrays::FromDataUtf8; use polars_core::prelude::*; // Vertically concatenate all strings in a StringChunked. diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index 4246ead4b6e11..b0008b7afb7ce 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -7,22 +7,17 @@ pub type InnerJoinIds = (JoinIds, JoinIds); #[cfg(feature = "chunked_ids")] pub(super) type ChunkJoinIds = Either, Vec>; #[cfg(feature = "chunked_ids")] -pub type ChunkJoinOptIds = Either>, Vec>>; +pub type ChunkJoinOptIds = Either, Vec>; #[cfg(not(feature = "chunked_ids"))] -pub type ChunkJoinOptIds = Vec>; +pub type ChunkJoinOptIds = Vec; #[cfg(not(feature = "chunked_ids"))] pub type ChunkJoinIds = Vec; -#[cfg(feature = "chunked_ids")] -use polars_utils::index::ChunkId; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "asof_join")] -use super::asof::AsOfOptions; - #[derive(Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinArgs { diff --git a/crates/polars-ops/src/frame/join/asof/groups.rs b/crates/polars-ops/src/frame/join/asof/groups.rs index f64f7d1009844..3fae9258c0c78 100644 --- a/crates/polars-ops/src/frame/join/asof/groups.rs +++ b/crates/polars-ops/src/frame/join/asof/groups.rs @@ -3,16 +3,17 @@ use std::hash::Hash; use ahash::RandomState; use num_traits::Zero; use polars_core::hashing::{_df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE}; +use polars_core::utils::flatten::flatten_nullable; use polars_core::utils::{split_ca, split_df}; -use polars_core::POOL; +use polars_core::{with_match_physical_float_polars_type, POOL}; use polars_utils::abs_diff::AbsDiff; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use rayon::prelude::*; use smartstring::alias::String as SmartString; use super::*; -use crate::frame::IntoDf; fn compute_len_offsets>(iter: I) -> Vec { let mut cumlen = 0; @@ -25,6 +26,14 @@ fn compute_len_offsets>(iter: I) -> Vec { .collect() } +#[inline(always)] +fn materialize_nullable(idx: Option) -> NullableIdxSize { + match idx { + Some(t) => NullableIdxSize::from(t), + None => NullableIdxSize::null(), + } +} + fn asof_in_group<'a, T, A, F>( left_val: T::Physical<'a>, right_val_arr: &'a T::Array, @@ -67,11 +76,12 @@ fn asof_join_by_numeric( left_asof: &ChunkedArray, right_asof: &ChunkedArray, filter: F, -) -> PolarsResult>> +) -> PolarsResult where T: PolarsDataType, S: PolarsNumericType, - S::Native: Hash + Eq + DirtyHash + IsNull, + S::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, A: for<'a> AsofJoinState>, F: Sync + for<'a> Fn(T::Physical<'a>, T::Physical<'a>) -> bool, { @@ -95,49 +105,48 @@ where let n_tables = hash_tbls.len(); // Now we probe the right hand side for each left hand side. - Ok(POOL - .install(|| { - split_by_left - .into_par_iter() - .zip(offsets) - .flat_map(|(by_left, offset)| { - let mut results = Vec::with_capacity(by_left.len()); - let mut group_states: PlHashMap = - PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); - - let by_left_chunk = by_left.downcast_iter().next().unwrap(); - for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() { - let Some(by_left_k) = opt_by_left_k else { - results.push(None); - continue; - }; - let idx_left = (rel_idx_left + offset) as IdxSize; - let Some(left_val) = left_val_arr.get(idx_left as usize) else { - results.push(None); - continue; - }; - - let group_probe_table = unsafe { - hash_tbls - .get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) - }; - let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else { - results.push(None); - continue; - }; - - results.push(asof_in_group::( - left_val, - right_val_arr, - right_grp_idxs.as_slice(), - &mut group_states, - &filter, - )); - } - results - }) - }) - .collect()) + let out = split_by_left + .into_par_iter() + .zip(offsets) + .map(|(by_left, offset)| { + let mut results = Vec::with_capacity(by_left.len()); + let mut group_states: PlHashMap = + PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); + + let by_left_chunk = by_left.downcast_iter().next().unwrap(); + for (rel_idx_left, opt_by_left_k) in by_left_chunk.iter().enumerate() { + let Some(by_left_k) = opt_by_left_k else { + results.push(NullableIdxSize::null()); + continue; + }; + let by_left_k = by_left_k.to_total_ord(); + let idx_left = (rel_idx_left + offset) as IdxSize; + let Some(left_val) = left_val_arr.get(idx_left as usize) else { + results.push(NullableIdxSize::null()); + continue; + }; + + let group_probe_table = unsafe { + hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) + }; + let Some(right_grp_idxs) = group_probe_table.get(&by_left_k) else { + results.push(NullableIdxSize::null()); + continue; + }; + let id = asof_in_group::( + left_val, + right_val_arr, + right_grp_idxs.as_slice(), + &mut group_states, + &filter, + ); + results.push(materialize_nullable(id)); + } + results + }); + + let bufs = POOL.install(|| out.collect::>()); + Ok(flatten_nullable(&bufs)) } fn asof_join_by_binary( @@ -146,7 +155,7 @@ fn asof_join_by_binary( left_asof: &ChunkedArray, right_asof: &ChunkedArray, filter: F, -) -> Vec> +) -> IdxArr where T: PolarsDataType, A: for<'a> AsofJoinState>, @@ -169,42 +178,41 @@ where let n_tables = hash_tbls.len(); // Now we probe the right hand side for each left hand side. - POOL.install(|| { - prep_by_left - .into_par_iter() - .zip(offsets) - .flat_map(|(by_left, offset)| { - let mut results = Vec::with_capacity(by_left.len()); - let mut group_states: PlHashMap<_, A> = - PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); - - for (rel_idx_left, by_left_k) in by_left.iter().enumerate() { - let idx_left = (rel_idx_left + offset) as IdxSize; - let Some(left_val) = left_val_arr.get(idx_left as usize) else { - results.push(None); - continue; - }; - - let group_probe_table = unsafe { - hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) - }; - let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else { - results.push(None); - continue; - }; - - results.push(asof_in_group::( - left_val, - right_val_arr, - right_grp_idxs.as_slice(), - &mut group_states, - &filter, - )); - } - results - }) - .collect() - }) + let iter = prep_by_left + .into_par_iter() + .zip(offsets) + .map(|(by_left, offset)| { + let mut results = Vec::with_capacity(by_left.len()); + let mut group_states: PlHashMap<_, A> = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); + + for (rel_idx_left, by_left_k) in by_left.iter().enumerate() { + let idx_left = (rel_idx_left + offset) as IdxSize; + let Some(left_val) = left_val_arr.get(idx_left as usize) else { + results.push(NullableIdxSize::null()); + continue; + }; + + let group_probe_table = unsafe { + hash_tbls.get_unchecked(hash_to_partition(by_left_k.dirty_hash(), n_tables)) + }; + let Some(right_grp_idxs) = group_probe_table.get(by_left_k) else { + results.push(NullableIdxSize::null()); + continue; + }; + let id = asof_in_group::( + left_val, + right_val_arr, + right_grp_idxs.as_slice(), + &mut group_states, + &filter, + ); + + results.push(materialize_nullable(id)); + } + results + }); + let bufs = POOL.install(|| iter.collect::>()); + flatten_nullable(&bufs) } fn asof_join_by_multiple( @@ -213,7 +221,7 @@ fn asof_join_by_multiple( left_asof: &ChunkedArray, right_asof: &ChunkedArray, filter: F, -) -> Vec> +) -> IdxArr where T: PolarsDataType, A: for<'a> AsofJoinState>, @@ -239,61 +247,60 @@ where let n_tables = hash_tbls.len(); // Now we probe the right hand side for each left hand side. - POOL.install(|| { - probe_hashes - .into_par_iter() - .zip(offsets) - .flat_map(|(hash_by_left, offset)| { - let mut results = Vec::with_capacity(hash_by_left.len()); - let mut group_states: PlHashMap<_, A> = - PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); - - let mut ctr = 0; - for by_left_view in hash_by_left.data_views() { - for h_left in by_left_view.iter().copied() { - let idx_left = offset + ctr; - ctr += 1; - let opt_left_val = left_val_arr.get(idx_left); - - let Some(left_val) = opt_left_val else { - results.push(None); - continue; - }; - - let group_probe_table = - unsafe { hash_tbls.get_unchecked(hash_to_partition(h_left, n_tables)) }; - - let entry = group_probe_table.raw_entry().from_hash(h_left, |idx_hash| { - let idx_right = idx_hash.idx; - // SAFETY: indices in a join operation are always in bounds. - unsafe { - mk::compare_df_rows2( - by_left, - by_right, - idx_left, - idx_right as usize, - false, - ) - } - }); - let Some((_, right_grp_idxs)) = entry else { - results.push(None); - continue; - }; - - results.push(asof_in_group::( - left_val, - right_val_arr, - &right_grp_idxs[..], - &mut group_states, - &filter, - )); - } + let iter = probe_hashes + .into_par_iter() + .zip(offsets) + .map(|(hash_by_left, offset)| { + let mut results = Vec::with_capacity(hash_by_left.len()); + let mut group_states: PlHashMap<_, A> = PlHashMap::with_capacity(_HASHMAP_INIT_SIZE); + + let mut ctr = 0; + for by_left_view in hash_by_left.data_views() { + for h_left in by_left_view.iter().copied() { + let idx_left = offset + ctr; + ctr += 1; + let opt_left_val = left_val_arr.get(idx_left); + + let Some(left_val) = opt_left_val else { + results.push(NullableIdxSize::null()); + continue; + }; + + let group_probe_table = + unsafe { hash_tbls.get_unchecked(hash_to_partition(h_left, n_tables)) }; + + let entry = group_probe_table.raw_entry().from_hash(h_left, |idx_hash| { + let idx_right = idx_hash.idx; + // SAFETY: indices in a join operation are always in bounds. + unsafe { + mk::compare_df_rows2( + by_left, + by_right, + idx_left, + idx_right as usize, + false, + ) + } + }); + let Some((_, right_grp_idxs)) = entry else { + results.push(NullableIdxSize::null()); + continue; + }; + let id = asof_in_group::( + left_val, + right_val_arr, + &right_grp_idxs[..], + &mut group_states, + &filter, + ); + + results.push(materialize_nullable(id)); } - results - }) - .collect() - }) + } + results + }); + let bufs = POOL.install(|| iter.collect::>()); + flatten_nullable(&bufs) } #[allow(clippy::too_many_arguments)] @@ -303,7 +310,7 @@ fn dispatch_join_by_type( left_by: &mut DataFrame, right_by: &mut DataFrame, filter: F, -) -> PolarsResult>> +) -> PolarsResult where T: PolarsDataType, A: for<'a> AsofJoinState>, @@ -329,7 +336,15 @@ where asof_join_by_binary::(left_by, right_by, left_asof, right_asof, filter) }, _ => { - if left_by_s.bit_repr_is_large() { + if left_by_s.dtype().is_float() { + with_match_physical_float_polars_type!(left_by_s.dtype(), |$T| { + let left_by: &ChunkedArray<$T> = left_by_s.as_ref().as_ref().as_ref(); + let right_by: &ChunkedArray<$T> = right_by_s.as_ref().as_ref().as_ref(); + asof_join_by_numeric::( + left_by, right_by, left_asof, right_asof, filter, + )? + }) + } else if left_by_s.bit_repr_is_large() { let left_by = left_by_s.bit_repr_large(); let right_by = right_by_s.bit_repr_large(); asof_join_by_numeric::( @@ -364,7 +379,7 @@ fn dispatch_join_strategy( left_by: &mut DataFrame, right_by: &mut DataFrame, strategy: AsofStrategy, -) -> PolarsResult>> +) -> PolarsResult where for<'a> T::Physical<'a>: PartialOrd, { @@ -390,7 +405,7 @@ fn dispatch_join_strategy_numeric( right_by: &mut DataFrame, strategy: AsofStrategy, tolerance: Option>, -) -> PolarsResult>> { +) -> PolarsResult { let right_ca = left_asof.unpack_series_matching_type(right_asof)?; if let Some(tol) = tolerance { @@ -432,7 +447,7 @@ fn dispatch_join_type( right_by: &mut DataFrame, strategy: AsofStrategy, tolerance: Option>, -) -> PolarsResult>> { +) -> PolarsResult { match left_asof.dtype() { DataType::Int64 => { let ca = left_asof.i64().unwrap(); @@ -560,15 +575,13 @@ pub trait AsofJoinBy: IntoDf { .filter(|s| !drop_these.contains(&s.name())) .cloned() .collect(); - let proj_other_df = DataFrame::new_no_checks(cols); + let proj_other_df = unsafe { DataFrame::new_no_checks(cols) }; let left = self_df.clone(); - let right_join_tuples = &*right_join_tuples; // SAFETY: join tuples are in bounds. - let right_df = unsafe { - proj_other_df.take_unchecked(&right_join_tuples.iter().copied().collect_ca("")) - }; + let right_df = + unsafe { proj_other_df.take_unchecked(&IdxCa::with_chunk("", right_join_tuples)) }; _finish_join(left, right_df, suffix) } diff --git a/crates/polars-ops/src/frame/join/cross_join.rs b/crates/polars-ops/src/frame/join/cross_join.rs index 73a34f0f613cb..1e1b1bcba4978 100644 --- a/crates/polars-ops/src/frame/join/cross_join.rs +++ b/crates/polars-ops/src/frame/join/cross_join.rs @@ -1,6 +1,4 @@ -use polars_core::series::IsSorted; -use polars_core::utils::{concat_df_unchecked, slice_offsets, CustomIterTools, NoNull}; -use polars_core::POOL; +use polars_core::utils::{concat_df_unchecked, CustomIterTools, NoNull}; use smartstring::alias::String as SmartString; use super::*; @@ -120,7 +118,7 @@ pub trait CrossJoin: IntoDf { Ok(l_df) } - /// Creates the cartesian product from both frames, preserves the order of the left keys. + /// Creates the Cartesian product from both frames, preserves the order of the left keys. fn cross_join( &self, other: &DataFrame, diff --git a/crates/polars-ops/src/frame/join/general.rs b/crates/polars-ops/src/frame/join/general.rs index eb8c6dfdb0d6f..74f837c849ec5 100644 --- a/crates/polars-ops/src/frame/join/general.rs +++ b/crates/polars-ops/src/frame/join/general.rs @@ -1,8 +1,3 @@ -use std::borrow::Cow; - -#[cfg(feature = "chunked_ids")] -use polars_utils::index::ChunkId; - use super::*; use crate::series::coalesce_series; @@ -37,7 +32,13 @@ pub fn _finish_join( let suffix = get_suffix(suffix); for name in rename_strs { - df_right.rename(&name, &_join_suffix_name(&name, suffix))?; + let new_name = _join_suffix_name(&name, suffix); + df_right.rename(&name, new_name.as_str()).map_err(|_| { + polars_err!(Duplicate: "column with name '{}' already exists\n\n\ + You may want to try:\n\ + - renaming the column prior to joining\n\ + - using the `suffix` parameter to specify a suffix different to the default one ('_right')", new_name) + })?; } drop(left_names); @@ -45,7 +46,7 @@ pub fn _finish_join( Ok(df_left) } -pub(super) fn coalesce_outer_join( +pub fn _coalesce_outer_join( mut df: DataFrame, keys_left: &[&str], keys_right: &[&str], diff --git a/crates/polars-ops/src/frame/join/hash_join/mod.rs b/crates/polars-ops/src/frame/join/hash_join/mod.rs index 30cd012de9cd0..f07667130cc52 100644 --- a/crates/polars-ops/src/frame/join/hash_join/mod.rs +++ b/crates/polars-ops/src/frame/join/hash_join/mod.rs @@ -10,9 +10,7 @@ pub(super) mod sort_merge; use arrow::array::ArrayRef; pub use multiple_keys::private_left_join_multiple_keys; pub(super) use multiple_keys::*; -#[cfg(any(feature = "chunked_ids", feature = "semi_anti_join"))] -use polars_core::utils::slice_slice; -use polars_core::utils::{_set_partition_size, slice_offsets, split_ca}; +use polars_core::utils::{_set_partition_size, split_ca}; use polars_core::POOL; use polars_utils::index::ChunkId; pub(super) use single_keys::*; @@ -56,8 +54,6 @@ macro_rules! det_hash_prone_order { use arrow::legacy::conversion::primitive_to_vec; pub(super) use det_hash_prone_order; -use crate::frame::join::general::coalesce_outer_join; - pub trait JoinDispatch: IntoDf { /// # Safety /// Join tuples must be in bounds @@ -114,7 +110,7 @@ pub trait JoinDispatch: IntoDf { let materialize_right = || { let right_idx = &*right_idx; - unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } + unsafe { IdxCa::with_nullable_idx(right_idx, |idx| other.take_unchecked(idx)) } }; let (df_left, df_right) = POOL.join(materialize_left, materialize_right); @@ -154,7 +150,7 @@ pub trait JoinDispatch: IntoDf { if let Some((offset, len)) = args.slice { right_idx = slice_slice(right_idx, offset, len); } - other.take_unchecked(&right_idx.iter().copied().collect_ca("")) + IdxCa::with_nullable_idx(right_idx, |idx| other.take_unchecked(idx)) }, ChunkJoinOptIds::Right(right_idx) => unsafe { let mut right_idx = &*right_idx; @@ -177,11 +173,11 @@ pub trait JoinDispatch: IntoDf { args: JoinArgs, verbose: bool, ) -> PolarsResult { - let ca_self = self.to_df(); + let df_self = self.to_df(); #[cfg(feature = "dtype-categorical")] _check_categorical_src(s_left.dtype(), s_right.dtype())?; - let mut left = ca_self.clone(); + let mut left = df_self.clone(); let mut s_left = s_left.clone(); // Eagerly limit left if possible. if let Some((offset, len)) = args.slice { @@ -192,16 +188,19 @@ pub trait JoinDispatch: IntoDf { } // Ensure that the chunks are aligned otherwise we go OOB. - let mut right = other.clone(); + let mut right = Cow::Borrowed(other); let mut s_right = s_right.clone(); if left.should_rechunk() { left.as_single_chunk_par(); s_left = s_left.rechunk(); } if right.should_rechunk() { - right.as_single_chunk_par(); + let mut other = other.clone(); + other.as_single_chunk_par(); + right = Cow::Owned(other); s_right = s_right.rechunk(); } + let ids = sort_or_hash_left(&s_left, &s_right, verbose, args.validation, args.join_nulls)?; left._finish_left_join(ids, &right.drop(s_right.name()).unwrap(), args) } @@ -273,7 +272,7 @@ pub trait JoinDispatch: IntoDf { }; let out = _finish_join(df_left, df_right, args.suffix.as_deref()); if coalesce { - Ok(coalesce_outer_join( + Ok(_coalesce_outer_join( out?, &[s_left.name()], &[s_right.name()], diff --git a/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs index be380c8a2de5c..119973c9671eb 100644 --- a/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/multiple_keys.rs @@ -1,11 +1,7 @@ use arrow::array::{MutablePrimitiveArray, PrimitiveArray}; -use hashbrown::hash_map::RawEntryMut; use hashbrown::HashMap; -use polars_core::hashing::{ - populate_multiple_key_hashmap, IdBuildHasher, IdxHash, _HASHMAP_INIT_SIZE, -}; -use polars_core::utils::{_set_partition_size, split_df}; -use polars_core::POOL; +use polars_core::hashing::{populate_multiple_key_hashmap, IdBuildHasher, IdxHash}; +use polars_core::utils::split_df; use polars_utils::hashing::hash_to_partition; use polars_utils::idx_vec::IdxVec; use polars_utils::unitvec; @@ -88,40 +84,39 @@ fn create_build_table_outer( // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|part_no| { - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || (false, unitvec![idx]), - |v| v.1.push(idx), - ) - } - idx += 1; - }); + let par_iter = (0..n_partitions).into_par_iter().map(|part_no| { + let mut hash_tbl: HashMap = + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + + let mut offset = 0; + for hashes in hashes { + for hashes in hashes.data_views() { + let len = hashes.len(); + let mut idx = 0; + hashes.iter().for_each(|h| { + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if part_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + populate_multiple_key_hashmap( + &mut hash_tbl, + idx, + *h, + keys, + || (false, unitvec![idx]), + |v| v.1.push(idx), + ) + } + idx += 1; + }); - offset += len as IdxSize; - } + offset += len as IdxSize; } - hash_tbl - }) - }) - .collect() + } + hash_tbl + }); + + POOL.install(|| par_iter.collect()) } /// Probe the build table and add tuples to the results (inner join) @@ -251,8 +246,8 @@ pub fn private_left_join_multiple_keys( chunk_mapping_right: Option<&[ChunkId]>, join_nulls: bool, ) -> LeftJoinIds { - let mut a = DataFrame::new_no_checks(_to_physical_and_bit_repr(a.get_columns())); - let mut b = DataFrame::new_no_checks(_to_physical_and_bit_repr(b.get_columns())); + let mut a = unsafe { DataFrame::new_no_checks(_to_physical_and_bit_repr(a.get_columns())) }; + let mut b = unsafe { DataFrame::new_no_checks(_to_physical_and_bit_repr(b.get_columns())) }; _left_join_multiple_keys( &mut a, &mut b, @@ -326,12 +321,13 @@ pub fn _left_join_multiple_keys( Some((_, indexes_b)) => { result_idx_left .extend(std::iter::repeat(idx_a).take(indexes_b.len())); - result_idx_right.extend(indexes_b.iter().copied().map(Some)) + let indexes_b = bytemuck::cast_slice(indexes_b); + result_idx_right.extend_from_slice(indexes_b); }, // only left values, right = null None => { result_idx_left.push(idx_a); - result_idx_right.push(None); + result_idx_right.push(NullableIdxSize::null()); }, } idx_a += 1; @@ -360,40 +356,32 @@ pub(crate) fn create_build_table_semi_anti( // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|part_no| { - let mut hash_tbl: HashMap = - HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); - - let mut offset = 0; - for hashes in hashes { - for hashes in hashes.data_views() { - let len = hashes.len(); - let mut idx = 0; - hashes.iter().for_each(|h| { - // partition hashes by thread no. - // So only a part of the hashes go to this hashmap - if part_no == hash_to_partition(*h, n_partitions) { - let idx = idx + offset; - populate_multiple_key_hashmap( - &mut hash_tbl, - idx, - *h, - keys, - || (), - |_| (), - ) - } - idx += 1; - }); + let par_iter = (0..n_partitions).into_par_iter().map(|part_no| { + let mut hash_tbl: HashMap = + HashMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); + + let mut offset = 0; + for hashes in hashes { + for hashes in hashes.data_views() { + let len = hashes.len(); + let mut idx = 0; + hashes.iter().for_each(|h| { + // partition hashes by thread no. + // So only a part of the hashes go to this hashmap + if part_no == hash_to_partition(*h, n_partitions) { + let idx = idx + offset; + populate_multiple_key_hashmap(&mut hash_tbl, idx, *h, keys, || (), |_| ()) + } + idx += 1; + }); - offset += len as IdxSize; - } + offset += len as IdxSize; } - hash_tbl - }) - }) - .collect() + } + hash_tbl + }); + + POOL.install(|| par_iter.collect()) } #[cfg(feature = "semi_anti_join")] @@ -423,46 +411,43 @@ pub(crate) fn semi_anti_join_multiple_keys_impl<'a>( // next we probe the other relation // code duplication is because we want to only do the swap check once - POOL.install(move || { - probe_hashes - .into_par_iter() - .zip(offsets) - .flat_map(move |(probe_hashes, offset)| { - // local reference - let hash_tbls = &hash_tbls; - let mut results = - Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads()); - let local_offset = offset; - - let mut idx_a = local_offset as IdxSize; - for probe_hashes in probe_hashes.data_views() { - for &h in probe_hashes { - // probe table that contains the hashed value - let current_probe_table = - unsafe { hash_tbls.get_unchecked(hash_to_partition(h, n_tables)) }; - - let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { - let idx_b = idx_hash.idx; - // SAFETY: - // indices in a join operation are always in bounds. - unsafe { - compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) - } - }); - - match entry { - // left and right matches - Some((_, _)) => results.push((idx_a, true)), - // only left values, right = null - None => results.push((idx_a, false)), + probe_hashes + .into_par_iter() + .zip(offsets) + .flat_map(move |(probe_hashes, offset)| { + // local reference + let hash_tbls = &hash_tbls; + let mut results = Vec::with_capacity(probe_hashes.len() / POOL.current_num_threads()); + let local_offset = offset; + + let mut idx_a = local_offset as IdxSize; + for probe_hashes in probe_hashes.data_views() { + for &h in probe_hashes { + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_tbls.get_unchecked(hash_to_partition(h, n_tables)) }; + + let entry = current_probe_table.raw_entry().from_hash(h, |idx_hash| { + let idx_b = idx_hash.idx; + // SAFETY: + // indices in a join operation are always in bounds. + unsafe { + compare_df_rows2(a, b, idx_a as usize, idx_b as usize, join_nulls) } - idx_a += 1; + }); + + match entry { + // left and right matches + Some((_, _)) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), } + idx_a += 1; } + } - results - }) - }) + results + }) } #[cfg(feature = "semi_anti_join")] @@ -471,10 +456,10 @@ pub fn _left_anti_multiple_keys( b: &mut DataFrame, join_nulls: bool, ) -> Vec { - semi_anti_join_multiple_keys_impl(a, b, join_nulls) + let par_iter = semi_anti_join_multiple_keys_impl(a, b, join_nulls) .filter(|tpls| !tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } #[cfg(feature = "semi_anti_join")] @@ -483,10 +468,10 @@ pub fn _left_semi_multiple_keys( b: &mut DataFrame, join_nulls: bool, ) -> Vec { - semi_anti_join_multiple_keys_impl(a, b, join_nulls) + let par_iter = semi_anti_join_multiple_keys_impl(a, b, join_nulls) .filter(|tpls| tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } /// Probe the build table and add tuples to the results (inner join) diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs index ee92dfdd6c455..38b59c2d7454c 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys.rs @@ -2,6 +2,7 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::idx_vec::IdxVec; use polars_utils::nulls::IsNull; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use polars_utils::unitvec; use super::*; @@ -12,9 +13,13 @@ use super::*; // Use a small element per thread threshold for debugging/testing purposes. const MIN_ELEMS_PER_THREAD: usize = if cfg!(debug_assertions) { 1 } else { 128 }; -pub(crate) fn build_tables(keys: Vec, join_nulls: bool) -> Vec> +pub(crate) fn build_tables( + keys: Vec, + join_nulls: bool, +) -> Vec::TotalOrdItem, IdxVec>> where - T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, I: IntoIterator + Send + Sync + Clone, { // FIXME: change interface to split the input here, instead of taking @@ -28,10 +33,11 @@ where // Don't bother parallelizing anything for small inputs. if num_keys_est < 2 * MIN_ELEMS_PER_THREAD { - let mut hm: PlHashMap = PlHashMap::new(); + let mut hm: PlHashMap = PlHashMap::new(); let mut offset = 0; for it in keys { for k in it { + let k = k.to_total_ord(); if !k.is_null() || join_nulls { hm.entry(k).or_default().push(offset); } @@ -49,6 +55,7 @@ where .map(|key_portion| { let mut partition_sizes = vec![0; n_partitions]; for key in key_portion.clone() { + let key = key.to_total_ord(); let p = hash_to_partition(key.dirty_hash(), n_partitions); unsafe { *partition_sizes.get_unchecked_mut(p) += 1; @@ -85,7 +92,7 @@ where } // Scatter values into partitions. - let mut scatter_keys: Vec = Vec::with_capacity(num_keys); + let mut scatter_keys: Vec = Vec::with_capacity(num_keys); let mut scatter_idxs: Vec = Vec::with_capacity(num_keys); let scatter_keys_ptr = unsafe { SyncPtr::new(scatter_keys.as_mut_ptr()) }; let scatter_idxs_ptr = unsafe { SyncPtr::new(scatter_idxs.as_mut_ptr()) }; @@ -96,6 +103,7 @@ where let mut partition_offsets = per_thread_partition_offsets[t * n_partitions..(t + 1) * n_partitions].to_vec(); for (i, key) in key_portion.into_iter().enumerate() { + let key = key.to_total_ord(); unsafe { let p = hash_to_partition(key.dirty_hash(), n_partitions); let off = partition_offsets.get_unchecked_mut(p); @@ -124,7 +132,8 @@ where let partition_range = partition_offsets[p]..partition_offsets[p + 1]; let full_size = partition_range.len(); let mut conservative_size = _HASHMAP_INIT_SIZE.max(full_size / 64); - let mut hm: PlHashMap = PlHashMap::with_capacity(conservative_size); + let mut hm: PlHashMap = + PlHashMap::with_capacity(conservative_size); unsafe { for i in partition_range { @@ -160,8 +169,6 @@ where pub(super) fn probe_to_offsets(probe: &[I]) -> Vec where I: IntoIterator + Clone, - // ::IntoIter: TrustedLen, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, { probe .iter() diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs index dca9e13260970..9468ac483d3dd 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_dispatch.rs @@ -1,7 +1,8 @@ use arrow::array::PrimitiveArray; -use num_traits::NumCast; +use polars_core::with_match_physical_float_polars_type; use polars_utils::hashing::DirtyHash; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; use crate::series::SeriesSealed; @@ -28,6 +29,12 @@ pub trait SeriesJoin: SeriesSealed + Sized { let lhs = lhs.iter().map(|v| v.as_slice()).collect::>(); let rhs = rhs.iter().map(|v| v.as_slice()).collect::>(); hash_join_tuples_left(lhs, rhs, None, None, validate, join_nulls) + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_left(lhs, rhs, validate, join_nulls) + }) } else if s_self.bit_repr_is_large() { let lhs = lhs.bit_repr_large(); let rhs = rhs.bit_repr_large(); @@ -58,6 +65,12 @@ pub trait SeriesJoin: SeriesSealed + Sized { } else { hash_join_tuples_left_semi(lhs, rhs) } + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + num_group_join_anti_semi(lhs, rhs, anti) + }) } else if s_self.bit_repr_is_large() { let lhs = lhs.bit_repr_large(); let rhs = rhs.bit_repr_large(); @@ -93,6 +106,12 @@ pub trait SeriesJoin: SeriesSealed + Sized { hash_join_tuples_inner(lhs, rhs, swapped, validate, join_nulls)?, !swapped, )) + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + group_join_inner::<$T>(lhs, rhs, validate, join_nulls) + }) } else if s_self.bit_repr_is_large() { let lhs = s_self.bit_repr_large(); let rhs = other.bit_repr_large(); @@ -124,6 +143,12 @@ pub trait SeriesJoin: SeriesSealed + Sized { let lhs = lhs.iter().collect::>(); let rhs = rhs.iter().collect::>(); hash_join_tuples_outer(lhs, rhs, swapped, validate, join_nulls) + } else if lhs.dtype().is_float() { + with_match_physical_float_polars_type!(lhs.dtype(), |$T| { + let lhs: &ChunkedArray<$T> = lhs.as_ref().as_ref().as_ref(); + let rhs: &ChunkedArray<$T> = rhs.as_ref().as_ref().as_ref(); + hash_join_outer(lhs, rhs, validate, join_nulls) + }) } else if s_self.bit_repr_is_large() { let lhs = s_self.bit_repr_large(); let rhs = other.bit_repr_large(); @@ -161,7 +186,10 @@ fn group_join_inner( where T: PolarsDataType, for<'a> &'a T::Array: IntoIterator>>, - for<'a> T::Physical<'a>: Hash + Eq + Send + DirtyHash + Copy + Send + Sync + IsNull, + for<'a> T::Physical<'a>: + Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + for<'a> as ToTotalOrd>::TotalOrdItem: + Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, { let n_threads = POOL.current_num_threads(); let (a, b, swapped) = det_hash_prone_order!(left, right); @@ -243,9 +271,11 @@ fn num_group_join_left( join_nulls: bool, ) -> PolarsResult where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + DirtyHash + IsNull, - Option: DirtyHash, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, + T::Native: DirtyHash + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Send + Sync + DirtyHash, { let n_threads = POOL.current_num_threads(); let splitted_a = split_ca(left, n_threads).unwrap(); @@ -300,8 +330,9 @@ fn hash_join_outer( join_nulls: bool, ) -> PolarsResult<(PrimitiveArray, PrimitiveArray)> where - T: PolarsIntegerType + Sync, - T::Native: Eq + Hash + NumCast, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + IsNull, { let (a, b, swapped) = det_hash_prone_order!(ca_in, other); @@ -395,9 +426,10 @@ fn num_group_join_anti_semi( anti: bool, ) -> Vec where - T: PolarsIntegerType, - T::Native: Hash + Eq + Send + DirtyHash, - Option: DirtyHash, + T: PolarsNumericType, + T::Native: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash, + as ToTotalOrd>::TotalOrdItem: Send + Sync + DirtyHash, { let n_threads = POOL.current_num_threads(); let splitted_a = split_ca(left, n_threads).unwrap(); diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs index bc5e5d4acdce2..58bdd286a8145 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_inner.rs @@ -4,23 +4,25 @@ use polars_utils::idx_vec::IdxVec; use polars_utils::iter::EnumerateIdxTrait; use polars_utils::nulls::IsNull; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; pub(super) fn probe_inner( probe: I, - hash_tbls: &[PlHashMap], + hash_tbls: &[PlHashMap<::TotalOrdItem, IdxVec>], results: &mut Vec<(IdxSize, IdxSize)>, local_offset: IdxSize, n_tables: usize, swap_fn: F, ) where - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + DirtyHash, I: IntoIterator, - // ::IntoIter: TrustedLen, F: Fn(IdxSize, IdxSize) -> (IdxSize, IdxSize), { probe.into_iter().enumerate_idx().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); let idx_a = idx_a + local_offset; // probe table that contains the hashed value let current_probe_table = @@ -45,8 +47,8 @@ pub(super) fn hash_join_tuples_inner( ) -> PolarsResult<(Vec, Vec)> where I: IntoIterator + Send + Sync + Clone, - // ::IntoIter: TrustedLen, - T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull, + T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, { // NOTE: see the left join for more elaborate comments // first we hash one relation diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs index 51956d41585dc..91c4f0cd10082 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_left.rs @@ -1,6 +1,7 @@ use polars_core::utils::flatten::flatten_par; use polars_utils::hashing::{hash_to_partition, DirtyHash}; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; @@ -12,19 +13,22 @@ unsafe fn apply_mapping(idx: Vec, chunk_mapping: &[ChunkId]) -> Vec>, - chunk_mapping: &[ChunkId], -) -> Vec> { +unsafe fn apply_opt_mapping(idx: Vec, chunk_mapping: &[ChunkId]) -> Vec { idx.iter() - .map(|opt_idx| opt_idx.map(|idx| *chunk_mapping.get_unchecked(idx as usize))) + .map(|opt_idx| { + if opt_idx.is_null_idx() { + ChunkId::null() + } else { + *chunk_mapping.get_unchecked(opt_idx.idx() as usize) + } + }) .collect() } #[cfg(feature = "chunked_ids")] pub(super) fn finish_left_join_mappings( result_idx_left: Vec, - result_idx_right: Vec>, + result_idx_right: Vec, chunk_mapping_left: Option<&[ChunkId]>, chunk_mapping_right: Option<&[ChunkId]>, ) -> LeftJoinIds { @@ -45,7 +49,7 @@ pub(super) fn finish_left_join_mappings( #[cfg(not(feature = "chunked_ids"))] pub(super) fn finish_left_join_mappings( _result_idx_left: Vec, - _result_idx_right: Vec>, + _result_idx_right: Vec, _chunk_mapping_left: Option<&[ChunkId]>, _chunk_mapping_right: Option<&[ChunkId]>, ) -> LeftJoinIds { @@ -112,7 +116,8 @@ pub(super) fn hash_join_tuples_left( where I: IntoIterator, ::IntoIter: Send + Sync + Clone, - T: Send + Hash + Eq + Sync + Copy + DirtyHash + IsNull, + T: Send + Sync + Copy + TotalHash + TotalEq + DirtyHash + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Copy + Hash + Eq + DirtyHash + IsNull, { let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); let build = build.into_iter().map(|i| i.into_iter()).collect::>(); @@ -147,6 +152,7 @@ where let mut result_idx_right = Vec::with_capacity(probe.size_hint().1.unwrap()); probe.enumerate().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); let idx_a = (idx_a + offset) as IdxSize; // probe table that contains the hashed value let current_probe_table = unsafe { @@ -160,12 +166,12 @@ where // left and right matches Some(indexes_b) => { result_idx_left.extend(std::iter::repeat(idx_a).take(indexes_b.len())); - result_idx_right.extend(indexes_b.iter().copied().map(Some)) + result_idx_right.extend_from_slice(bytemuck::cast_slice(indexes_b)); }, // only left values, right = null None => { result_idx_left.push(idx_a); - result_idx_right.push(None); + result_idx_right.push(NullableIdxSize::null()); }, } }); diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs index 33c4a376de876..f2e0f21ad8f70 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_outer.rs @@ -3,6 +3,7 @@ use arrow::legacy::utils::CustomIterTools; use polars_utils::hashing::hash_to_partition; use polars_utils::idx_vec::IdxVec; use polars_utils::nulls::IsNull; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use polars_utils::unitvec; use super::*; @@ -14,7 +15,8 @@ pub(crate) fn create_hash_and_keys_threaded_vectorized( where I: IntoIterator + Send, I::IntoIter: TrustedLen, - T: Send + Hash + Eq, + T: TotalHash + TotalEq + Send + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let build_hasher = build_hasher.unwrap_or_default(); let hashes = POOL.install(|| { @@ -23,7 +25,7 @@ where .map(|iter| { // create hashes and keys iter.into_iter() - .map(|val| (build_hasher.hash_one(&val), val)) + .map(|val| (build_hasher.hash_one(&val.to_total_ord()), val)) .collect_trusted::>() }) .collect() @@ -33,10 +35,11 @@ where pub(crate) fn prepare_hashed_relation_threaded( iters: Vec, -) -> Vec> +) -> Vec::TotalOrdItem, (bool, IdxVec)>> where I: Iterator + Send + TrustedLen, - T: Send + Hash + Eq + Sync + Copy, + T: Send + Sync + TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq, { let n_partitions = _set_partition_size(); let (hashes_and_keys, build_hasher) = create_hash_and_keys_threaded_vectorized(iters, None); @@ -50,7 +53,7 @@ where .map(|partition_no| { let build_hasher = build_hasher.clone(); let hashes_and_keys = &hashes_and_keys; - let mut hash_tbl: PlHashMap = + let mut hash_tbl: PlHashMap = PlHashMap::with_hasher(build_hasher); let mut offset = 0; @@ -60,6 +63,7 @@ where .iter() .enumerate() .for_each(|(idx, (h, k))| { + let k = k.to_total_ord(); let idx = idx as IdxSize; // partition hashes by thread no. // So only a part of the hashes go to this hashmap @@ -68,11 +72,11 @@ where let entry = hash_tbl .raw_entry_mut() // uses the key to check equality to find and entry - .from_key_hashed_nocheck(*h, k); + .from_key_hashed_nocheck(*h, &k); match entry { RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(*h, *k, (false, unitvec![idx])); + entry.insert_hashed_nocheck(*h, k, (false, unitvec![idx])); }, RawEntryMut::Occupied(mut entry) => { let (_k, v) = entry.get_key_value_mut(); @@ -94,7 +98,7 @@ where #[allow(clippy::too_many_arguments)] fn probe_outer( probe_hashes: &[Vec<(u64, T)>], - hash_tbls: &mut [PlHashMap], + hash_tbls: &mut [PlHashMap<::TotalOrdItem, (bool, IdxVec)>], results: &mut ( MutablePrimitiveArray, MutablePrimitiveArray, @@ -108,7 +112,8 @@ fn probe_outer( swap_fn_drain: H, join_nulls: bool, ) where - T: Send + Hash + Eq + Sync + Copy + IsNull, + T: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + IsNull, // idx_a, idx_b -> ... F: Fn(IdxSize, IdxSize) -> (Option, Option), // idx_a -> ... @@ -120,6 +125,7 @@ fn probe_outer( let mut idx_a = 0; for probe_hashes in probe_hashes { for (h, key) in probe_hashes { + let key = key.to_total_ord(); let h = *h; // probe table that contains the hashed value let current_probe_table = @@ -127,7 +133,7 @@ fn probe_outer( let entry = current_probe_table .raw_entry_mut() - .from_key_hashed_nocheck(h, key); + .from_key_hashed_nocheck(h, &key); match entry { // match and remove @@ -182,7 +188,8 @@ where J: IntoIterator, ::IntoIter: TrustedLen + Send, ::IntoIter: TrustedLen + Send, - T: Hash + Eq + Copy + Sync + Send + IsNull, + T: Send + Sync + TotalHash + TotalEq + IsNull + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + IsNull, { let probe = probe.into_iter().map(|i| i.into_iter()).collect::>(); let build = build.into_iter().map(|i| i.into_iter()).collect::>(); diff --git a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs index 93268036c43dd..57196e86632d0 100644 --- a/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs +++ b/crates/polars-ops/src/frame/join/hash_join/single_keys_semi_anti.rs @@ -1,11 +1,15 @@ use polars_utils::hashing::{hash_to_partition, DirtyHash}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; /// Only keeps track of membership in right table -pub(super) fn create_probe_table_semi_anti(keys: Vec) -> Vec> +pub(super) fn create_probe_table_semi_anti( + keys: Vec, +) -> Vec::TotalOrdItem>> where - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, I: IntoIterator + Copy + Send + Sync, { let n_partitions = _set_partition_size(); @@ -13,29 +17,31 @@ where // We will create a hashtable in every thread. // We use the hash to partition the keys to the matching hashtable. // Every thread traverses all keys/hashes and ignores the ones that doesn't fall in that partition. - POOL.install(|| { - (0..n_partitions).into_par_iter().map(|partition_no| { - let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); - for keys in &keys { - keys.into_iter().for_each(|k| { - if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) { - hash_tbl.insert(k); - } - }); - } - hash_tbl - }) - }) - .collect() + let par_iter = (0..n_partitions).into_par_iter().map(|partition_no| { + let mut hash_tbl: PlHashSet = PlHashSet::with_capacity(_HASHMAP_INIT_SIZE); + for keys in &keys { + keys.into_iter().for_each(|k| { + let k = k.to_total_ord(); + if partition_no == hash_to_partition(k.dirty_hash(), n_partitions) { + hash_tbl.insert(k); + } + }); + } + hash_tbl + }); + POOL.install(|| par_iter.collect()) } -pub(super) fn semi_anti_impl( +/// Construct a ParallelIterator, but doesn't iterate it. This means the caller +/// context (or wherever it gets iterated) should be in POOL.install. +fn semi_anti_impl( probe: Vec, build: Vec, ) -> impl ParallelIterator where I: IntoIterator + Copy + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, { // first we hash one relation let hash_sets = create_probe_table_semi_anti(build); @@ -46,60 +52,61 @@ where let n_tables = hash_sets.len(); // next we probe the other relation - POOL.install(move || { - probe - .into_par_iter() - .zip(offsets) - // probes_hashes: Vec processed by this thread - // offset: offset index - .flat_map(move |(probe, offset)| { - // local reference - let hash_sets = &hash_sets; - let probe_iter = probe.into_iter(); + // This is not wrapped in POOL.install because it is not being iterated here + probe + .into_par_iter() + .zip(offsets) + // probes_hashes: Vec processed by this thread + // offset: offset index + .flat_map(move |(probe, offset)| { + // local reference + let hash_sets = &hash_sets; + let probe_iter = probe.into_iter(); - // assume the result tuples equal length of the no. of hashes processed by this thread. - let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); + // assume the result tuples equal length of the no. of hashes processed by this thread. + let mut results = Vec::with_capacity(probe_iter.size_hint().1.unwrap()); - probe_iter.enumerate().for_each(|(idx_a, k)| { - let idx_a = (idx_a + offset) as IdxSize; - // probe table that contains the hashed value - let current_probe_table = unsafe { - hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) - }; + probe_iter.enumerate().for_each(|(idx_a, k)| { + let k = k.to_total_ord(); + let idx_a = (idx_a + offset) as IdxSize; + // probe table that contains the hashed value + let current_probe_table = + unsafe { hash_sets.get_unchecked(hash_to_partition(k.dirty_hash(), n_tables)) }; - // we already hashed, so we don't have to hash again. - let value = current_probe_table.get(&k); + // we already hashed, so we don't have to hash again. + let value = current_probe_table.get(&k); - match value { - // left and right matches - Some(_) => results.push((idx_a, true)), - // only left values, right = null - None => results.push((idx_a, false)), - } - }); - results - }) - }) + match value { + // left and right matches + Some(_) => results.push((idx_a, true)), + // only left values, right = null + None => results.push((idx_a, false)), + } + }); + results + }) } pub(super) fn hash_join_tuples_left_anti(probe: Vec, build: Vec) -> Vec where I: IntoIterator + Copy + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, { - semi_anti_impl(probe, build) + let par_iter = semi_anti_impl(probe, build) .filter(|tpls| !tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } pub(super) fn hash_join_tuples_left_semi(probe: Vec, build: Vec) -> Vec where I: IntoIterator + Copy + Send + Sync, - T: Send + Hash + Eq + Sync + Copy + DirtyHash, + T: TotalHash + TotalEq + DirtyHash + ToTotalOrd, + ::TotalOrdItem: Send + Sync + Hash + Eq + DirtyHash, { - semi_anti_impl(probe, build) + let par_iter = semi_anti_impl(probe, build) .filter(|tpls| tpls.1) - .map(|tpls| tpls.0) - .collect() + .map(|tpls| tpls.0); + POOL.install(|| par_iter.collect()) } diff --git a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs index d9b849ce1e599..c19d6f5f9de99 100644 --- a/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs +++ b/crates/polars-ops/src/frame/join/hash_join/sort_merge.rs @@ -11,7 +11,7 @@ use super::*; fn par_sorted_merge_left_impl( s_left: &ChunkedArray, s_right: &ChunkedArray, -) -> (Vec, Vec>) +) -> (Vec, Vec) where T: PolarsNumericType, { @@ -23,13 +23,12 @@ where let slice_left = s_left.cont_slice().unwrap(); let slice_right = s_right.cont_slice().unwrap(); - let indexes = offsets - .into_par_iter() - .map(|(offset, len)| { - let slice_left = &slice_left[offset..offset + len]; - sorted_join::left::join(slice_left, slice_right, offset as IdxSize) - }) - .collect::>(); + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::left::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + let lefts = indexes.iter().map(|t| &t.0).collect::>(); let rights = indexes.iter().map(|t| &t.1).collect::>(); @@ -40,7 +39,7 @@ where pub(super) fn par_sorted_merge_left( s_left: &Series, s_right: &Series, -) -> (Vec, Vec>) { +) -> (Vec, Vec) { // Don't use bit_repr here. It messes up sortedness. debug_assert_eq!(s_left.dtype(), s_right.dtype()); let s_left = s_left.to_physical_repr(); @@ -96,13 +95,12 @@ where let slice_left = s_left.cont_slice().unwrap(); let slice_right = s_right.cont_slice().unwrap(); - let indexes = offsets - .into_par_iter() - .map(|(offset, len)| { - let slice_left = &slice_left[offset..offset + len]; - sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) - }) - .collect::>(); + let indexes = offsets.into_par_iter().map(|(offset, len)| { + let slice_left = &slice_left[offset..offset + len]; + sorted_join::inner::join(slice_left, slice_right, offset as IdxSize) + }); + let indexes = POOL.install(|| indexes.collect::>()); + let lefts = indexes.iter().map(|t| &t.0).collect::>(); let rights = indexes.iter().map(|t| &t.1).collect::>(); @@ -155,7 +153,7 @@ pub(super) fn par_sorted_merge_inner_no_nulls( } #[cfg(feature = "performant")] -fn to_left_join_ids(left_idx: Vec, right_idx: Vec>) -> LeftJoinIds { +fn to_left_join_ids(left_idx: Vec, right_idx: Vec) -> LeftJoinIds { #[cfg(feature = "chunked_ids")] { (Either::Left(left_idx), Either::Left(right_idx)) @@ -334,8 +332,11 @@ pub(super) fn sort_or_hash_left( POOL.install(|| { right.par_iter_mut().for_each(|opt_idx| { - *opt_idx = - opt_idx.map(|idx| unsafe { *reverse_idx_map.get_unchecked(idx as usize) }); + if !opt_idx.is_null_idx() { + *opt_idx = + unsafe { *reverse_idx_map.get_unchecked(opt_idx.idx() as usize) } + .into(); + } }); }); diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs index ef097bb33049c..fc687aaa623f5 100644 --- a/crates/polars-ops/src/frame/join/merge_sorted.rs +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -43,7 +43,7 @@ pub fn _merge_sorted_dfs( }) .collect(); - Ok(DataFrame::new_no_checks(new_columns)) + Ok(unsafe { DataFrame::new_no_checks(new_columns) }) } fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> Series { diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index 4c98d4d238c5f..ccc4c72184bdc 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -9,7 +9,6 @@ mod hash_join; #[cfg(feature = "merge_sorted")] mod merge_sorted; -#[cfg(feature = "chunked_ids")] use std::borrow::Cow; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; @@ -26,7 +25,7 @@ pub use cross_join::CrossJoin; use either::Either; #[cfg(feature = "chunked_ids")] use general::create_chunked_index_mapping; -pub use general::{_finish_join, _join_suffix_name}; +pub use general::{_coalesce_outer_join, _finish_join, _join_suffix_name}; pub use hash_join::*; use hashbrown::hash_map::{Entry, RawEntryMut}; #[cfg(feature = "merge_sorted")] @@ -34,13 +33,14 @@ pub use merge_sorted::_merge_sorted_dfs; use polars_core::hashing::{_df_rows_to_hashes_threaded_vertical, _HASHMAP_INIT_SIZE}; use polars_core::prelude::*; pub(super) use polars_core::series::IsSorted; -use polars_core::utils::{_to_physical_and_bit_repr, slice_offsets, slice_slice}; +#[allow(unused_imports)] +use polars_core::utils::slice_slice; +use polars_core::utils::{_to_physical_and_bit_repr, slice_offsets}; use polars_core::POOL; use polars_utils::hashing::BytesHash; use rayon::prelude::*; use super::IntoDf; -use crate::frame::join::general::coalesce_outer_join; pub trait DataFrameJoinOps: IntoDf { /// Generic join method. Can be used to join on multiple columns. @@ -273,8 +273,8 @@ pub trait DataFrameJoinOps: IntoDf { // Multiple keys. match args.how { JoinType::Inner => { - let left = DataFrame::new_no_checks(selected_left_physical); - let right = DataFrame::new_no_checks(selected_right_physical); + let left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; let (mut left, mut right, swap) = det_hash_prone_order!(left, right); let (join_idx_left, join_idx_right) = _inner_join_multiple_keys(&mut left, &mut right, swap, args.join_nulls); @@ -298,8 +298,8 @@ pub trait DataFrameJoinOps: IntoDf { _finish_join(df_left, df_right, args.suffix.as_deref()) }, JoinType::Left => { - let mut left = DataFrame::new_no_checks(selected_left_physical); - let mut right = DataFrame::new_no_checks(selected_right_physical); + let mut left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let mut right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; if let Some((offset, len)) = args.slice { left = left.slice(offset, len); @@ -309,8 +309,8 @@ pub trait DataFrameJoinOps: IntoDf { left_df._finish_left_join(ids, &remove_selected(other, &selected_right), args) }, JoinType::Outer { .. } => { - let df_left = DataFrame::new_no_checks(selected_left_physical); - let df_right = DataFrame::new_no_checks(selected_right_physical); + let df_left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let df_right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; let (mut left, mut right, swap) = det_hash_prone_order!(df_left, df_right); let (mut join_idx_l, mut join_idx_r) = @@ -337,7 +337,7 @@ pub trait DataFrameJoinOps: IntoDf { let names_right = selected_right.iter().map(|s| s.name()).collect::>(); let out = _finish_join(df_left, df_right, args.suffix.as_deref()); if coalesce { - Ok(coalesce_outer_join( + Ok(_coalesce_outer_join( out?, &names_left, &names_right, @@ -354,8 +354,8 @@ pub trait DataFrameJoinOps: IntoDf { ), #[cfg(feature = "semi_anti_join")] JoinType::Anti | JoinType::Semi => { - let mut left = DataFrame::new_no_checks(selected_left_physical); - let mut right = DataFrame::new_no_checks(selected_right_physical); + let mut left = unsafe { DataFrame::new_no_checks(selected_left_physical) }; + let mut right = unsafe { DataFrame::new_no_checks(selected_right_physical) }; let idx = if matches!(args.how, JoinType::Anti) { _left_anti_multiple_keys(&mut left, &mut right, args.join_nulls) diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 3760eb3e61a7c..cec9ddd01cdb0 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -82,27 +82,23 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { /// # Note /// Polars'/arrow memory is not ideal for transposing operations like pivots. /// If you have a relatively large table, consider using a group_by over a pivot. -pub fn pivot( +pub fn pivot( pivot_df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_fn: Option, separator: Option<&str>, ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { - let values = values - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect::>(); let index = index .into_iter() .map(|s| s.as_ref().to_string()) @@ -111,11 +107,12 @@ where .into_iter() .map(|s| s.as_ref().to_string()) .collect::>(); + let values = get_values_columns(pivot_df, &index, &columns, values); pivot_impl( pivot_df, - &values, &index, &columns, + &values, agg_fn, sort_columns, false, @@ -128,27 +125,23 @@ where /// # Note /// Polars'/arrow memory is not ideal for transposing operations like pivots. /// If you have a relatively large table, consider using a group_by over a pivot. -pub fn pivot_stable( +pub fn pivot_stable( pivot_df: &DataFrame, - values: I0, - index: I1, - columns: I2, + index: I0, + columns: I1, + values: Option, sort_columns: bool, agg_fn: Option, separator: Option<&str>, ) -> PolarsResult where I0: IntoIterator, - S0: AsRef, I1: IntoIterator, - S1: AsRef, I2: IntoIterator, + S0: AsRef, + S1: AsRef, S2: AsRef, { - let values = values - .into_iter() - .map(|s| s.as_ref().to_string()) - .collect::>(); let index = index .into_iter() .map(|s| s.as_ref().to_string()) @@ -157,12 +150,12 @@ where .into_iter() .map(|s| s.as_ref().to_string()) .collect::>(); - + let values = get_values_columns(pivot_df, &index, &columns, values); pivot_impl( pivot_df, - &values, &index, &columns, + &values, agg_fn, sort_columns, true, @@ -170,16 +163,41 @@ where ) } +/// Determine `values` columns, which is optional in `pivot` calls. +/// +/// If not specified (i.e. is `None`), use all remaining columns in the +/// `DataFrame` after `index` and `columns` have been excluded. +fn get_values_columns( + df: &DataFrame, + index: &[String], + columns: &[String], + values: Option, +) -> Vec +where + I: IntoIterator, + S: AsRef, +{ + match values { + Some(v) => v.into_iter().map(|s| s.as_ref().to_string()).collect(), + None => df + .get_column_names() + .into_iter() + .map(|c| c.to_string()) + .filter(|c| !(index.contains(c) | columns.contains(c))) + .collect(), + } +} + #[allow(clippy::too_many_arguments)] fn pivot_impl( pivot_df: &DataFrame, - // these columns will be aggregated in the nested group_by - values: &[String], // keys of the first group_by operation index: &[String], // these columns will be used for a nested group_by // the rows of this nested group_by will be pivoted as header column values columns: &[String], + // these columns will be aggregated in the nested group_by + values: &[String], // aggregation function agg_fn: Option, sort_columns: bool, @@ -206,9 +224,9 @@ fn pivot_impl( let pivot_df = unsafe { binding.with_column_unchecked(columns_struct) }; pivot_impl_single_column( pivot_df, + index, &column, values, - index, agg_fn, sort_columns, separator, @@ -216,9 +234,9 @@ fn pivot_impl( } else { pivot_impl_single_column( pivot_df, + index, unsafe { columns.get_unchecked(0) }, values, - index, agg_fn, sort_columns, separator, @@ -228,9 +246,9 @@ fn pivot_impl( fn pivot_impl_single_column( pivot_df: &DataFrame, + index: &[String], column: &str, values: &[String], - index: &[String], agg_fn: Option, sort_columns: bool, separator: Option<&str>, @@ -274,7 +292,7 @@ fn pivot_impl_single_column( let name = expr.root_name()?; let mut value_col = value_col.clone(); value_col.rename(name); - let tmp_df = DataFrame::new_no_checks(vec![value_col]); + let tmp_df = value_col.into_frame(); let mut aggregated = expr.evaluate(&tmp_df, &groups)?; aggregated.rename(value_col_name); aggregated @@ -341,5 +359,7 @@ fn pivot_impl_single_column( Ok(()) }); out?; - DataFrame::new_no_length_checks(final_cols) + + // SAFETY: length has already been checked. + unsafe { DataFrame::new_no_length_checks(final_cols) } } diff --git a/crates/polars-ops/src/frame/pivot/positioning.rs b/crates/polars-ops/src/frame/pivot/positioning.rs index d494c80f6bdfe..5ad0b32f101d9 100644 --- a/crates/polars-ops/src/frame/pivot/positioning.rs +++ b/crates/polars-ops/src/frame/pivot/positioning.rs @@ -3,6 +3,7 @@ use std::hash::Hash; use arrow::legacy::trusted_len::TrustedLenPush; use polars_core::prelude::*; use polars_utils::sync::SyncPtr; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; use super::*; @@ -175,23 +176,23 @@ where fn compute_col_idx_numeric(column_agg_physical: &ChunkedArray) -> Vec where T: PolarsNumericType, - T::Native: Hash + Eq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE); let mut idx = 0 as IdxSize; let mut out = Vec::with_capacity(column_agg_physical.len()); - for arr in column_agg_physical.downcast_iter() { - for opt_v in arr.into_iter() { - let idx = *col_to_idx.entry(opt_v).or_insert_with(|| { - let old_idx = idx; - idx += 1; - old_idx - }); - // SAFETY: - // we pre-allocated - unsafe { out.push_unchecked(idx) }; - } + for opt_v in column_agg_physical.iter() { + let opt_v = opt_v.to_total_ord(); + let idx = *col_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); + // SAFETY: + // we pre-allocated + unsafe { out.push_unchecked(idx) }; } out } @@ -232,14 +233,22 @@ pub(super) fn compute_col_idx( use DataType::*; let col_locations = match column_agg_physical.dtype() { - Int32 | UInt32 | Float32 => { + Int32 | UInt32 => { let ca = column_agg_physical.bit_repr_small(); compute_col_idx_numeric(&ca) }, - Int64 | UInt64 | Float64 => { + Int64 | UInt64 => { let ca = column_agg_physical.bit_repr_large(); compute_col_idx_numeric(&ca) }, + Float64 => { + let ca: &ChunkedArray = column_agg_physical.as_ref().as_ref().as_ref(); + compute_col_idx_numeric(ca) + }, + Float32 => { + let ca: &ChunkedArray = column_agg_physical.as_ref().as_ref().as_ref(); + compute_col_idx_numeric(ca) + }, Struct(_) => { let ca = column_agg_physical.struct_().unwrap(); let ca = ca.rows_encode()?; @@ -286,7 +295,8 @@ fn compute_row_index<'a, T>( ) -> (Vec, usize, Option>) where T: PolarsDataType, - T::Physical<'a>: Hash + Eq + Copy, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, ChunkedArray: FromIterator>>, ChunkedArray: IntoSeries, { @@ -295,26 +305,25 @@ where let mut idx = 0 as IdxSize; let mut row_locations = Vec::with_capacity(index_agg_physical.len()); - for arr in index_agg_physical.downcast_iter() { - for opt_v in arr.iter() { - let idx = *row_to_idx.entry(opt_v).or_insert_with(|| { - let old_idx = idx; - idx += 1; - old_idx - }); + for opt_v in index_agg_physical.iter() { + let opt_v = opt_v.to_total_ord(); + let idx = *row_to_idx.entry(opt_v).or_insert_with(|| { + let old_idx = idx; + idx += 1; + old_idx + }); - // SAFETY: - // we pre-allocated - unsafe { - row_locations.push_unchecked(idx); - } + // SAFETY: + // we pre-allocated + unsafe { + row_locations.push_unchecked(idx); } } let row_index = match count { 0 => { let mut s = row_to_idx .into_iter() - .map(|(k, _)| k) + .map(|(k, _)| Option::>::peel_total_ord(k)) .collect::>() .into_series(); s.rename(&index[0]); @@ -386,14 +395,22 @@ pub(super) fn compute_row_idx( use DataType::*; match index_agg_physical.dtype() { - Int32 | UInt32 | Float32 => { + Int32 | UInt32 => { let ca = index_agg_physical.bit_repr_small(); compute_row_index(index, &ca, count, index_s.dtype()) }, - Int64 | UInt64 | Float64 => { + Int64 | UInt64 => { let ca = index_agg_physical.bit_repr_large(); compute_row_index(index, &ca, count, index_s.dtype()) }, + Float64 => { + let ca: &ChunkedArray = index_agg_physical.as_ref().as_ref().as_ref(); + compute_row_index(index, ca, count, index_s.dtype()) + }, + Float32 => { + let ca: &ChunkedArray = index_agg_physical.as_ref().as_ref().as_ref(); + compute_row_index(index, ca, count, index_s.dtype()) + }, Boolean => { let ca = index_agg_physical.bool().unwrap(); compute_row_index(index, ca, count, index_s.dtype()) diff --git a/crates/polars-ops/src/series/ops/approx_unique.rs b/crates/polars-ops/src/series/ops/approx_unique.rs index fe5d703723956..31093e06b77ac 100644 --- a/crates/polars-ops/src/series/ops/approx_unique.rs +++ b/crates/polars-ops/src/series/ops/approx_unique.rs @@ -2,6 +2,7 @@ use std::hash::Hash; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; #[cfg(feature = "approx_unique")] use crate::series::ops::approx_algo::HyperLogLog; @@ -9,10 +10,11 @@ use crate::series::ops::approx_algo::HyperLogLog; fn approx_n_unique_ca<'a, T>(ca: &'a ChunkedArray) -> PolarsResult where T: PolarsDataType, - T::Physical<'a>: Hash + Eq, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, { let mut hllp = HyperLogLog::new(); - ca.iter().for_each(|item| hllp.add(&item)); + ca.iter().for_each(|item| hllp.add(&item.to_total_ord())); let c = hllp.count() as IdxSize; Ok(Series::new(ca.name(), &[c])) @@ -28,8 +30,12 @@ fn dispatcher(s: &Series) -> PolarsResult { let ca = s.str().unwrap().as_binary(); approx_n_unique_ca(&ca) }, - Float32 => approx_n_unique_ca(&s.bit_repr_small()), - Float64 => approx_n_unique_ca(&s.bit_repr_large()), + Float32 => approx_n_unique_ca(AsRef::>::as_ref( + s.as_ref().as_ref(), + )), + Float64 => approx_n_unique_ca(AsRef::>::as_ref( + s.as_ref().as_ref(), + )), dt if dt.is_numeric() => { with_match_physical_integer_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); diff --git a/crates/polars-ops/src/series/ops/arg_min_max.rs b/crates/polars-ops/src/series/ops/arg_min_max.rs index 563d9c96f4300..cd1d7d5f78162 100644 --- a/crates/polars-ops/src/series/ops/arg_min_max.rs +++ b/crates/polars-ops/src/series/ops/arg_min_max.rs @@ -1,6 +1,9 @@ use argminmax::ArgMinMax; use arrow::array::Array; use arrow::legacy::bit_util::*; +use polars_core::chunked_array::ops::float_sorted_arg_max::{ + float_arg_max_sorted_ascending, float_arg_max_sorted_descending, +}; use polars_core::series::IsSorted; use polars_core::with_match_physical_numeric_polars_type; @@ -22,7 +25,7 @@ impl ArgAgg for Series { #[cfg(feature = "dtype-categorical")] Categorical(_, _) => { let ca = self.categorical().unwrap(); - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { return None; } if ca.uses_lexical_ordering() { @@ -69,7 +72,7 @@ impl ArgAgg for Series { #[cfg(feature = "dtype-categorical")] Categorical(_, _) => { let ca = self.categorical().unwrap(); - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { return None; } if ca.uses_lexical_ordering() { @@ -114,8 +117,10 @@ where T: PolarsNumericType, for<'b> &'b [T::Native]: ArgMinMax, { - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { None + } else if T::get_dtype().is_float() && !matches!(ca.is_sorted_flag(), IsSorted::Not) { + arg_max_float_sorted(ca) } else if let Ok(vals) = ca.cont_slice() { arg_max_numeric_slice(vals, ca.is_sorted_flag()) } else { @@ -128,7 +133,7 @@ where T: PolarsNumericType, for<'b> &'b [T::Native]: ArgMinMax, { - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { None } else if let Ok(vals) = ca.cont_slice() { arg_min_numeric_slice(vals, ca.is_sorted_flag()) @@ -138,7 +143,7 @@ where } pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { None } // don't check for any, that on itself is already an argmax search @@ -162,8 +167,23 @@ pub(crate) fn arg_max_bool(ca: &BooleanChunked) -> Option { } } +/// # Safety +/// `ca` has a float dtype, has at least one non-null value and is sorted. +fn arg_max_float_sorted(ca: &ChunkedArray) -> Option +where + T: PolarsNumericType, +{ + let out = match ca.is_sorted_flag() { + IsSorted::Ascending => float_arg_max_sorted_ascending(ca), + IsSorted::Descending => float_arg_max_sorted_descending(ca), + _ => unreachable!(), + }; + + Some(out) +} + fn arg_min_bool(ca: &BooleanChunked) -> Option { - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { None } else if ca.null_count() == 0 && ca.chunks().len() == 1 { let arr = ca.downcast_iter().next().unwrap(); @@ -186,7 +206,7 @@ fn arg_min_bool(ca: &BooleanChunked) -> Option { } fn arg_min_str(ca: &StringChunked) -> Option { - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { return None; } match ca.is_sorted_flag() { @@ -202,7 +222,7 @@ fn arg_min_str(ca: &StringChunked) -> Option { } fn arg_max_str(ca: &StringChunked) -> Option { - if ca.is_empty() || ca.null_count() == ca.len() { + if ca.null_count() == ca.len() { return None; } match ca.is_sorted_flag() { diff --git a/crates/polars-ops/src/series/ops/cum_agg.rs b/crates/polars-ops/src/series/ops/cum_agg.rs index e47b3c4c84278..03aed306f1189 100644 --- a/crates/polars-ops/src/series/ops/cum_agg.rs +++ b/crates/polars-ops/src/series/ops/cum_agg.rs @@ -1,8 +1,8 @@ -use std::iter::FromIterator; use std::ops::{Add, AddAssign, Mul}; use num_traits::{Bounded, One, Zero}; use polars_core::prelude::*; +use polars_core::series::IsSorted; use polars_core::utils::{CustomIterTools, NoNull}; use polars_core::with_match_physical_numeric_polars_type; @@ -208,33 +208,37 @@ pub fn cum_max(s: &Series, reverse: bool) -> PolarsResult { } pub fn cum_count(s: &Series, reverse: bool) -> PolarsResult { - // Fast paths for no nulls - if s.null_count() == 0 { - let out = cum_count_no_nulls(s.name(), s.len(), reverse); - return Ok(out); - } - - let ca = s.is_not_null(); - let out: IdxCa = if reverse { - let mut count = (s.len() - s.null_count()) as IdxSize; - let mut prev = false; - ca.apply_values_generic(|v: bool| { - if prev { - count -= 1; - } - prev = v; - count - }) + let mut out = if s.null_count() == 0 { + // Fast paths for no nulls + cum_count_no_nulls(s.name(), s.len(), reverse) } else { - let mut count = 0 as IdxSize; - ca.apply_values_generic(|v: bool| { - if v { - count += 1; - } - count - }) + let ca = s.is_not_null(); + let out: IdxCa = if reverse { + let mut count = (s.len() - s.null_count()) as IdxSize; + let mut prev = false; + ca.apply_values_generic(|v: bool| { + if prev { + count -= 1; + } + prev = v; + count + }) + } else { + let mut count = 0 as IdxSize; + ca.apply_values_generic(|v: bool| { + if v { + count += 1; + } + count + }) + }; + + out.into() }; - Ok(out.into()) + + out.set_sorted_flag([IsSorted::Ascending, IsSorted::Descending][reverse as usize]); + + Ok(out) } fn cum_count_no_nulls(name: &str, len: usize, reverse: bool) -> Series { diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index df9ee97c4b327..b7e87a23d8a89 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -1,6 +1,3 @@ -use std::cmp::PartialOrd; -use std::iter::once; - use polars_core::prelude::*; fn map_cats( @@ -82,8 +79,8 @@ pub fn cut( polars_ensure!(ll.len() == sorted_breaks.len() + 1, ShapeMismatch: "Provide nbreaks + 1 labels"); ll }, - None => (once(&f64::NEG_INFINITY).chain(sorted_breaks.iter())) - .zip(sorted_breaks.iter().chain(once(&f64::INFINITY))) + None => (std::iter::once(&f64::NEG_INFINITY).chain(sorted_breaks.iter())) + .zip(sorted_breaks.iter().chain(std::iter::once(&f64::INFINITY))) .map(|v| { if left_closed { format!("[{}, {})", v.0, v.1) diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs index cbc16b6abc37d..22b99a04a8921 100644 --- a/crates/polars-ops/src/series/ops/ewm.rs +++ b/crates/polars-ops/src/series/ops/ewm.rs @@ -1,6 +1,3 @@ -use std::convert::TryFrom; - -use arrow::array::ArrayRef; pub use arrow::legacy::kernels::ewm::EWMOptions; use arrow::legacy::kernels::ewm::{ ewm_mean as kernel_ewm_mean, ewm_std as kernel_ewm_std, ewm_var as kernel_ewm_var, diff --git a/crates/polars-ops/src/series/ops/floor_divide.rs b/crates/polars-ops/src/series/ops/floor_divide.rs index 68468bf887b43..85d8750a53136 100644 --- a/crates/polars-ops/src/series/ops/floor_divide.rs +++ b/crates/polars-ops/src/series/ops/floor_divide.rs @@ -1,6 +1,5 @@ use polars_compute::arithmetic::ArithmeticKernel; use polars_core::chunked_array::ops::arity::apply_binary_kernel_broadcast; -use polars_core::datatypes::PolarsNumericType; use polars_core::prelude::*; #[cfg(feature = "dtype-struct")] use polars_core::series::arithmetic::_struct_arithmetic; diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 003589657158c..fd4dd76d24348 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -42,25 +42,25 @@ pub fn all_horizontal(s: &[Series]) -> PolarsResult { } pub fn max_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.max_horizontal() .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) } pub fn min_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.min_horizontal() .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) } pub fn sum_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.sum_horizontal(NullStrategy::Ignore) .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) } pub fn mean_horizontal(s: &[Series]) -> PolarsResult> { - let df = DataFrame::new_no_checks(Vec::from(s)); + let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; df.mean_horizontal(NullStrategy::Ignore) .map(|opt_s| opt_s.map(|res| res.with_name(s[0].name()))) } diff --git a/crates/polars-ops/src/series/ops/index.rs b/crates/polars-ops/src/series/ops/index.rs index fc2823e81fc06..368780566d8fb 100644 --- a/crates/polars-ops/src/series/ops/index.rs +++ b/crates/polars-ops/src/series/ops/index.rs @@ -1,5 +1,5 @@ use num_traits::{Signed, Zero}; -use polars_core::error::{polars_bail, polars_ensure, PolarsResult}; +use polars_core::error::{polars_ensure, PolarsResult}; use polars_core::prelude::{ChunkedArray, DataType, IdxCa, PolarsIntegerType, Series, IDX_DTYPE}; use polars_utils::index::ToIdx; diff --git a/crates/polars-ops/src/series/ops/is_first_distinct.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs index 178c80bb980d9..b75ae23dba1f7 100644 --- a/crates/polars-ops/src/series/ops/is_first_distinct.rs +++ b/crates/polars-ops/src/series/ops/is_first_distinct.rs @@ -5,16 +5,18 @@ use arrow::bitmap::MutableBitmap; use arrow::legacy::bit_util::*; use arrow::legacy::utils::CustomIterTools; use polars_core::prelude::*; -use polars_core::with_match_physical_integer_polars_type; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn is_first_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, - T::Native: Hash + Eq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut unique = PlHashSet::new(); let chunks = ca.downcast_iter().map(|arr| -> BooleanArray { arr.into_iter() - .map(|opt_v| unique.insert(opt_v)) + .map(|opt_v| unique.insert(opt_v.to_total_ord())) .collect_trusted() }); @@ -126,16 +128,8 @@ pub fn is_first_distinct(s: &Series) -> PolarsResult { let s = s.cast(&Binary).unwrap(); return is_first_distinct(&s); }, - Float32 => { - let ca = s.bit_repr_small(); - is_first_distinct_numeric(&ca) - }, - Float64 => { - let ca = s.bit_repr_large(); - is_first_distinct_numeric(&ca) - }, dt if dt.is_numeric() => { - with_match_physical_integer_polars_type!(s.dtype(), |$T| { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); is_first_distinct_numeric(ca) }) diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index ca8b5d559233f..13dace1c4b50c 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -1,11 +1,11 @@ -#[cfg(feature = "dtype-categorical")] -use polars_core::apply_amortized_generic_list_or_array; +use std::hash::Hash; + use polars_core::prelude::*; use polars_core::utils::{try_get_supertype, CustomIterTools}; use polars_core::with_match_physical_numeric_polars_type; #[cfg(feature = "dtype-categorical")] use polars_utils::iter::EnumerateIdxTrait; -use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn is_in_helper_ca<'a, T>( ca: &'a ChunkedArray, @@ -13,25 +13,27 @@ fn is_in_helper_ca<'a, T>( ) -> PolarsResult where T: PolarsDataType, - T::Physical<'a>: TotalHash + TotalEq + Copy, + T::Physical<'a>: TotalHash + TotalEq + ToTotalOrd + Copy, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let mut set = PlHashSet::with_capacity(other.len()); other.downcast_iter().for_each(|iter| { iter.iter().for_each(|opt_val| { if let Some(v) = opt_val { - set.insert(TotalOrdWrap(v)); + set.insert(v.to_total_ord()); } }) }); Ok(ca - .apply_values_generic(|val| set.contains(&TotalOrdWrap(val))) + .apply_values_generic(|val| set.contains(&val.to_total_ord())) .with_name(ca.name())) } fn is_in_helper<'a, T>(ca: &'a ChunkedArray, other: &Series) -> PolarsResult where T: PolarsDataType, - T::Physical<'a>: TotalHash + TotalEq + Copy, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, { let other = ca.unpack_series_matching_type(other)?; is_in_helper_ca(ca, other) @@ -112,7 +114,8 @@ where fn is_in_numeric(ca_in: &ChunkedArray, other: &Series) -> PolarsResult where T: PolarsNumericType, - T::Native: TotalHash + TotalEq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq + Copy, { // We check implicitly cast to supertype here match other.dtype() { @@ -149,47 +152,57 @@ where } #[cfg(feature = "dtype-categorical")] -fn is_in_string_inner_categorical( +fn is_in_string_list_categorical( ca_in: &StringChunked, other: &Series, rev_map: &Arc, ) -> PolarsResult { - let opt_val = ca_in.get(0); - match opt_val { - None => { - let out = - apply_amortized_generic_list_or_array!(other, apply_amortized_generic, |opt_s| { - opt_s.map(|s| Some(s.as_ref().null_count() > 0) == Some(true)) - }); - Ok(out.with_name(ca_in.name())) - }, - Some(value) => { - match rev_map.find(value) { - // all false - None => Ok(BooleanChunked::full(ca_in.name(), false, other.len())), - Some(idx) => { - let out = apply_amortized_generic_list_or_array!( - other, - apply_amortized_generic, - |opt_s| { - Some( - opt_s.map(|s| { - let s = s.as_ref().to_physical_repr(); - let ca = s.as_ref().u32().unwrap(); - if ca.null_count() == 0 { - ca.into_no_null_iter().any(|a| a == idx) - } else { - ca.iter().any(|a| a == Some(idx)) - } - }) == Some(true), - ) - } - ); - Ok(out.with_name(ca_in.name())) - }, - } - }, - } + let mut ca = if ca_in.len() == 1 && other.len() != 1 { + let opt_val = ca_in.get(0); + match opt_val.map(|val| rev_map.find(val)) { + None => other.list()?.apply_amortized_generic(|opt_s| { + { + opt_s.map(|s| s.as_ref().null_count() > 0) + } + }), + Some(None) => other + .list()? + .apply_amortized_generic(|opt_s| opt_s.map(|_| false)), + Some(Some(idx)) => other.list()?.apply_amortized_generic(|opt_s| { + opt_s.map(|s| { + let s = s.as_ref().to_physical_repr(); + let ca = s.as_ref().u32().unwrap(); + if ca.null_count() == 0 { + ca.into_no_null_iter().any(|a| a == idx) + } else { + ca.iter().any(|a| a == Some(idx)) + } + }) + }), + } + } else { + polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + ca_in + .iter() + .zip(other.list()?.amortized_iter()) + .map(|(opt_val, series)| match (opt_val, series) { + (opt_val, Some(series)) => match opt_val.map(|val| rev_map.find(val)) { + None => Some(series.as_ref().null_count() > 0), + Some(None) => Some(false), + Some(Some(idx)) => { + let ca = series.as_ref().categorical().unwrap(); + Some(ca.physical().iter().any(|el| el == Some(idx))) + }, + }, + _ => None, + }) + .collect() + } + }; + ca.rename(ca_in.name()); + Ok(ca) } fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult { @@ -200,18 +213,7 @@ fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult { - is_in_string_inner_categorical(ca_in, other, rev_map) - }, - _ => unreachable!(), - } - }, - #[cfg(all(feature = "dtype-categorical", feature = "dtype-array"))] - DataType::Array(dt, _) - if matches!(&**dt, DataType::Categorical(_, _) | DataType::Enum(_, _)) => - { - match &**dt { - DataType::Enum(Some(rev_map), _) | DataType::Categorical(Some(rev_map), _) => { - is_in_string_inner_categorical(ca_in, other, rev_map) + is_in_string_list_categorical(ca_in, other, rev_map) }, _ => unreachable!(), } @@ -232,6 +234,10 @@ fn is_in_string(ca_in: &StringChunked, other: &Series) -> PolarsResult { is_in_binary(&ca_in.as_binary(), &other.cast(&DataType::Binary).unwrap()) }, + #[cfg(feature = "dtype-categorical")] + DataType::Enum(_, _) | DataType::Categorical(_, _) => { + is_in_string_categorical(ca_in, other.categorical().unwrap()) + }, _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } } @@ -556,6 +562,24 @@ fn is_in_struct(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult { + // In case of fast unique, we can directly use the categories. Otherwise we need to + // first get the unique physicals + let categories = StringChunked::with_chunk("", other.get_rev_map().get_categories().clone()); + let other = if other._can_fast_unique() { + categories + } else { + let s = other.physical().unique()?.cast(&IDX_DTYPE)?; + // SAFETY: Invariant of categorical means indices are in bound + unsafe { categories.take_unchecked(s.idx()?) } + }; + is_in_helper_ca(&ca_in.as_binary(), &other.as_binary()) +} + #[cfg(feature = "dtype-categorical")] fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult { match other.dtype() { @@ -580,7 +604,7 @@ fn is_in_cat(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult + { + is_in_cat_list(ca_in, other) + }, + _ => polars_bail!(opq = is_in, ca_in.dtype(), other.dtype()), } } +#[cfg(feature = "dtype-categorical")] +fn is_in_cat_list(ca_in: &CategoricalChunked, other: &Series) -> PolarsResult { + let list_chunked = other.list()?; + + let mut ca: BooleanChunked = if ca_in.len() == 1 && other.len() != 1 { + let (DataType::Categorical(Some(rev_map), _) | DataType::Enum(Some(rev_map), _)) = + list_chunked.inner_dtype() + else { + unreachable!(); + }; + + let idx = ca_in.physical().get(0); + let new_phys = idx + .map(|idx| ca_in.get_rev_map().get(idx)) + .map(|s| rev_map.find(s)); + + match new_phys { + None => list_chunked + .apply_amortized_generic(|opt_s| opt_s.map(|s| s.as_ref().null_count() > 0)), + Some(None) => list_chunked.apply_amortized_generic(|opt_s| opt_s.map(|_| false)), + Some(Some(idx)) => list_chunked.apply_amortized_generic(|opt_s| { + opt_s.map(|s| { + let ca = s.as_ref().categorical().unwrap(); + ca.physical().iter().any(|a| a == Some(idx)) + }) + }), + } + } else { + polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); + let list_chunked_inner = list_chunked.get_inner(); + let inner_cat = list_chunked_inner.categorical()?; + // Make physicals compatible of ca_in with those of the list + let (_, ca_in) = make_categoricals_compatible(inner_cat, ca_in)?; + + // SAFETY: unstable series never lives longer than the iterator. + unsafe { + ca_in + .physical() + .iter() + .zip(list_chunked.amortized_iter()) + .map(|(value, series)| match (value, series) { + (val, Some(series)) => { + let ca = series.as_ref().categorical().unwrap(); + Some(ca.physical().iter().any(|a| a == val)) + }, + _ => None, + }) + .collect_trusted() + } + }; + ca.rename(ca_in.name()); + Ok(ca) +} + pub fn is_in(s: &Series, other: &Series) -> PolarsResult { match s.dtype() { #[cfg(feature = "dtype-categorical")] diff --git a/crates/polars-ops/src/series/ops/is_last_distinct.rs b/crates/polars-ops/src/series/ops/is_last_distinct.rs index 84fe94a5c0022..40d57d4381511 100644 --- a/crates/polars-ops/src/series/ops/is_last_distinct.rs +++ b/crates/polars-ops/src/series/ops/is_last_distinct.rs @@ -5,7 +5,8 @@ use arrow::bitmap::MutableBitmap; use arrow::legacy::utils::CustomIterTools; use polars_core::prelude::*; use polars_core::utils::NoNull; -use polars_core::with_match_physical_integer_polars_type; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; pub fn is_last_distinct(s: &Series) -> PolarsResult { // fast path. @@ -31,16 +32,8 @@ pub fn is_last_distinct(s: &Series) -> PolarsResult { let s = s.cast(&Binary).unwrap(); return is_last_distinct(&s); }, - Float32 => { - let ca = s.bit_repr_small(); - is_last_distinct_numeric(&ca) - }, - Float64 => { - let ca = s.bit_repr_large(); - is_last_distinct_numeric(&ca) - }, dt if dt.is_numeric() => { - with_match_physical_integer_polars_type!(s.dtype(), |$T| { + with_match_physical_numeric_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); is_last_distinct_numeric(ca) }) @@ -131,7 +124,8 @@ fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { fn is_last_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, - T::Native: Hash + Eq, + T::Native: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); @@ -139,7 +133,7 @@ where let mut new_ca: BooleanChunked = arr .into_iter() .rev() - .map(|opt_v| unique.insert(opt_v)) + .map(|opt_v| unique.insert(opt_v.to_total_ord())) .collect_reversed::>() .into_inner(); new_ca.rename(ca.name()); diff --git a/crates/polars-ops/src/series/ops/is_unique.rs b/crates/polars-ops/src/series/ops/is_unique.rs index 3e3f09f5b3af3..265e8736b35e6 100644 --- a/crates/polars-ops/src/series/ops/is_unique.rs +++ b/crates/polars-ops/src/series/ops/is_unique.rs @@ -1,14 +1,17 @@ +use std::hash::Hash; + use arrow::array::BooleanArray; use arrow::bitmap::MutableBitmap; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; -use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap}; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; // If invert is true then this is an `is_duplicated`. fn is_unique_ca<'a, T>(ca: &'a ChunkedArray, invert: bool) -> BooleanChunked where T: PolarsDataType, - T::Physical<'a>: TotalHash + TotalEq, + T::Physical<'a>: TotalHash + TotalEq + Copy + ToTotalOrd, + > as ToTotalOrd>::TotalOrdItem: Hash + Eq, { let len = ca.len(); let mut idx_key = PlHashMap::new(); @@ -17,7 +20,7 @@ where // just toggle a boolean that's false if a group has multiple entries. ca.iter().enumerate().for_each(|(idx, key)| { idx_key - .entry(TotalOrdWrap(key)) + .entry(key.to_total_ord()) .and_modify(|v: &mut (IdxSize, bool)| v.1 = false) .or_insert((idx as IdxSize, true)); }); diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 8a64afbd9fbcb..9670a296e95b4 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -98,6 +98,7 @@ pub use moment::*; pub use negate::*; #[cfg(feature = "pct_change")] pub use pct_change::*; +pub use polars_core::chunked_array::ops::search_sorted::SearchSortedSide; use polars_core::prelude::*; #[cfg(feature = "rank")] pub use rank::*; diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs index 8bcc3347fc665..dd2fe3936945f 100644 --- a/crates/polars-ops/src/series/ops/rank.rs +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -1,11 +1,7 @@ use arrow::array::BooleanArray; use arrow::compute::concatenate::concatenate_validities; use polars_core::prelude::*; -#[cfg(feature = "random")] -use rand::prelude::SliceRandom; use rand::prelude::*; -#[cfg(feature = "random")] -use rand::{rngs::SmallRng, SeedableRng}; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index 752355b68242b..c169bff7f70de 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -2,11 +2,10 @@ use std::ops::BitOr; use polars_core::prelude::*; use polars_core::utils::try_get_supertype; -use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use polars_error::polars_ensure; use crate::frame::join::*; use crate::prelude::*; -use crate::series::is_in; pub fn replace( s: &Series, @@ -25,19 +24,21 @@ pub fn replace( None => try_get_supertype(new.dtype(), default.dtype())?, }; - let default = match default.len() { - len if len == s.len() => default.cast(&return_dtype)?, - 1 => default.cast(&return_dtype)?.new_from_index(0, s.len()), - _ => { - polars_bail!( - ComputeError: - "`default` input for `replace` must have the same length as the input or have length 1" - ) - }, - }; + polars_ensure!( + default.len() == s.len() || default.len() == 1, + ComputeError: "`default` input for `replace` must have the same length as the input or have length 1" + ); + + let default = default.cast(&return_dtype)?; if old.len() == 0 { - return Ok(default); + let out = if default.len() == 1 && s.len() != 1 { + default.new_from_index(0, s.len()) + } else { + default + }; + + return Ok(out); } let old = match (s.dtype(), old.dtype()) { @@ -90,7 +91,7 @@ fn replace_by_multiple( ComputeError: "`new` input for `replace` must have the same length as `old` or have length 1" ); - let df = DataFrame::new_no_checks(vec![s.clone()]); + let df = s.clone().into_frame(); let replacer = create_replacer(old, new)?; let joined = df.join( @@ -133,6 +134,6 @@ fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult } else { vec![old, new] }; - let out = DataFrame::new_no_checks(cols); + let out = unsafe { DataFrame::new_no_checks(cols) }; Ok(out) } diff --git a/crates/polars-ops/src/series/ops/search_sorted.rs b/crates/polars-ops/src/series/ops/search_sorted.rs index 09f0835481241..b7376d9b553f2 100644 --- a/crates/polars-ops/src/series/ops/search_sorted.rs +++ b/crates/polars-ops/src/series/ops/search_sorted.rs @@ -1,132 +1,7 @@ -use std::cmp::Ordering; -use std::fmt::Debug; - use arrow::array::Array; -use arrow::legacy::prelude::*; +use polars_core::chunked_array::ops::search_sorted::{binary_search_array, SearchSortedSide}; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_polars_type; -use polars_utils::total_ord::{TotalEq, TotalOrd}; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, Default)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum SearchSortedSide { - #[default] - Any, - Left, - Right, -} - -/// Search the left or right index that still fulfills the requirements. -fn finish_side<'a, A>( - side: SearchSortedSide, - out: &mut Vec, - mid: IdxSize, - arr: &'a A, - len: usize, -) where - A: StaticArray, - A::ValueT<'a>: TotalOrd + Debug + Copy, -{ - let mut mid = mid; - - // approach the boundary from any side - // this is O(n) we could make this binary search later - match side { - SearchSortedSide::Any => { - out.push(mid); - }, - SearchSortedSide::Left => { - if mid as usize == len { - mid -= 1; - } - - let current = unsafe { arr.get_unchecked(mid as usize) }; - loop { - if mid == 0 { - out.push(mid); - break; - } - mid -= 1; - if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) { - out.push(mid + 1); - break; - } - } - }, - SearchSortedSide::Right => { - if mid as usize == len { - out.push(mid); - return; - } - let current = unsafe { arr.get_unchecked(mid as usize) }; - let bound = (len - 1) as IdxSize; - loop { - if mid >= bound { - out.push(mid + 1); - break; - } - mid += 1; - if current.tot_ne(unsafe { &arr.get_unchecked(mid as usize) }) { - out.push(mid); - break; - } - } - }, - } -} - -fn binary_search_array<'a, A>( - side: SearchSortedSide, - out: &mut Vec, - arr: &'a A, - len: usize, - search_value: A::ValueT<'a>, - descending: bool, -) where - A: StaticArray, - A::ValueT<'a>: TotalOrd + Debug + Copy, -{ - let mut size = len as IdxSize; - let mut left = 0 as IdxSize; - let mut right = size; - let current_len = out.len(); - while left < right { - let mid = left + size / 2; - - // SAFETY: the call is made safe by the following invariants: - // - `mid >= 0` - // - `mid < size`: `mid` is limited by `[left; right)` bound. - let cmp = match unsafe { arr.get_unchecked(mid as usize) } { - None => Ordering::Less, - Some(value) => { - if descending { - search_value.tot_cmp(&value) - } else { - value.tot_cmp(&search_value) - } - }, - }; - - // The reason why we use if/else control flow rather than match - // is because match reorders comparison operations, which is perf sensitive. - // This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra. - if cmp == Ordering::Less { - left = mid + 1; - } else if cmp == Ordering::Greater { - right = mid; - } else { - finish_side(side, out, mid, arr, len); - break; - } - - size = right - left; - } - if out.len() == current_len { - out.push(left); - } -} fn search_sorted_ca_array( ca: &ChunkedArray, @@ -145,20 +20,15 @@ where for search_arr in search_values.downcast_iter() { if search_arr.null_count() == 0 { for search_value in search_arr.values_iter() { - binary_search_array(side, &mut out, arr, ca.len(), *search_value, descending) + out.push(binary_search_array(side, arr, *search_value, descending)) } } else { for opt_v in search_arr.into_iter() { match opt_v { None => out.push(0), - Some(search_value) => binary_search_array( - side, - &mut out, - arr, - ca.len(), - *search_value, - descending, - ), + Some(search_value) => { + out.push(binary_search_array(side, arr, *search_value, descending)) + }, } } } @@ -180,14 +50,14 @@ fn search_sorted_bin_array_with_binary_offset( for search_arr in search_values.downcast_iter() { if search_arr.null_count() == 0 { for search_value in search_arr.values_iter() { - binary_search_array(side, &mut out, arr, ca.len(), search_value, descending) + out.push(binary_search_array(side, arr, search_value, descending)) } } else { for opt_v in search_arr.into_iter() { match opt_v { None => out.push(0), Some(search_value) => { - binary_search_array(side, &mut out, arr, ca.len(), search_value, descending) + out.push(binary_search_array(side, arr, search_value, descending)) }, } } @@ -210,14 +80,14 @@ fn search_sorted_bin_array( for search_arr in search_values.downcast_iter() { if search_arr.null_count() == 0 { for search_value in search_arr.values_iter() { - binary_search_array(side, &mut out, arr, ca.len(), search_value, descending) + out.push(binary_search_array(side, arr, search_value, descending)) } } else { for opt_v in search_arr.into_iter() { match opt_v { None => out.push(0), Some(search_value) => { - binary_search_array(side, &mut out, arr, ca.len(), search_value, descending) + out.push(binary_search_array(side, arr, search_value, descending)) }, } } diff --git a/crates/polars-ops/src/series/ops/to_dummies.rs b/crates/polars-ops/src/series/ops/to_dummies.rs index 8b59192dcaa06..f2d8c4f3b70a8 100644 --- a/crates/polars-ops/src/series/ops/to_dummies.rs +++ b/crates/polars-ops/src/series/ops/to_dummies.rs @@ -1,5 +1,3 @@ -use polars_core::frame::group_by::GroupsIndicator; - use super::*; #[cfg(feature = "dtype-u8")] @@ -48,7 +46,7 @@ impl ToDummies for Series { }) .collect(); - Ok(DataFrame::new_no_checks(sort_columns(columns))) + Ok(unsafe { DataFrame::new_no_checks(sort_columns(columns)) }) } } diff --git a/crates/polars-ops/src/series/ops/unique.rs b/crates/polars-ops/src/series/ops/unique.rs index e35847b120a89..3a2d9b5652fe8 100644 --- a/crates/polars-ops/src/series/ops/unique.rs +++ b/crates/polars-ops/src/series/ops/unique.rs @@ -3,14 +3,18 @@ use std::hash::Hash; use polars_core::hashing::_HASHMAP_INIT_SIZE; use polars_core::prelude::*; use polars_core::utils::NoNull; +use polars_core::with_match_physical_numeric_polars_type; +use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash}; fn unique_counts_helper(items: I) -> IdxCa where I: Iterator, - J: Hash + Eq, + J: TotalHash + TotalEq + ToTotalOrd, + ::TotalOrdItem: Hash + Eq, { let mut map = PlIndexMap::with_capacity_and_hasher(_HASHMAP_INIT_SIZE, Default::default()); for item in items { + let item = item.to_total_ord(); map.entry(item) .and_modify(|cnt| { *cnt += 1; @@ -24,13 +28,12 @@ where /// Returns a count of the unique values in the order of appearance. pub fn unique_counts(s: &Series) -> PolarsResult { if s.dtype().to_physical().is_numeric() { - if s.bit_repr_is_large() { - let ca = s.bit_repr_large(); - Ok(unique_counts_helper(ca.iter()).into_series()) - } else { - let ca = s.bit_repr_small(); + let s_physical = s.to_physical_repr(); + + with_match_physical_numeric_polars_type!(s_physical.dtype(), |$T| { + let ca: &ChunkedArray<$T> = s_physical.as_ref().as_ref().as_ref(); Ok(unique_counts_helper(ca.iter()).into_series()) - } + }) } else { match s.dtype() { DataType::String => { diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index ecc341ea020e3..cad413816ced4 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "hash")] -use polars_core::export::ahash; #[cfg(feature = "dtype-struct")] use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca; use polars_core::prelude::*; @@ -21,7 +19,7 @@ pub trait SeriesMethods: SeriesSealed { let values = unsafe { s.agg_first(&groups) }; let counts = groups.group_lengths("count"); let cols = vec![values, counts.into_series()]; - let df = DataFrame::new_no_checks(cols); + let df = unsafe { DataFrame::new_no_checks(cols) }; if sort { df.sort(["count"], true, false) } else { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs index 2ecc2c51e223d..aa74d1cc9b4a8 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binary/basic.rs @@ -8,6 +8,7 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::{ArrowDataType, PhysicalType}; use arrow::offset::Offset; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::utils::{extend_from_decoder, next, DecodedState, MaybeNext}; use super::super::{utils, PagesIter}; @@ -121,8 +122,9 @@ impl<'a, O: Offset> utils::Decoder<'a> for BinaryDecoder { &mut page_values .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)), - ) + .map(|index| page_dict.value(index as usize)), + ); + page_values.values.get_result()?; }, BinaryState::RequiredDictionary(page) => { // Already done on the dict. @@ -132,11 +134,12 @@ impl<'a, O: Offset> utils::Decoder<'a> for BinaryDecoder { for x in page .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)) + .map(|index| page_dict.value(index as usize)) .take(additional) { values.push(x) } + page.values.get_result()?; }, BinaryState::FilteredOptional(page_validity, page_values) => { extend_from_decoder( @@ -160,14 +163,15 @@ impl<'a, O: Offset> utils::Decoder<'a> for BinaryDecoder { // Already done on the dict. validate_utf8 = false; let page_dict = &page.dict; - for x in page + for x in &mut page .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)) + .map(|index| page_dict.value(index as usize)) .take(additional) { values.push(x) } + page.values.iter.get_result()?; }, BinaryState::FilteredOptionalDictionary(page_validity, page_values) => { // Already done on the dict. @@ -181,8 +185,9 @@ impl<'a, O: Offset> utils::Decoder<'a> for BinaryDecoder { &mut page_values .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)), - ) + .map(|index| page_dict.value(index as usize)), + ); + page_values.values.get_result()?; }, BinaryState::OptionalDeltaByteArray(page_validity, page_values) => extend_from_decoder( validity, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/nested.rs index f3f6c9226e7cd..b2c5129833ff4 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binary/nested.rs @@ -5,6 +5,7 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::offset::Offset; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::nested_utils::*; use super::super::utils::MaybeNext; @@ -60,17 +61,19 @@ impl<'a, O: Offset> NestedDecoder<'a> for BinaryDecoder { let item = page .values .next() - .map(|index| dict_values.value(index.unwrap() as usize)) + .map(|index| dict_values.value(index as usize)) .unwrap_or_default(); values.push(item); + page.values.get_result()?; }, BinaryNestedState::OptionalDictionary(page) => { let dict_values = &page.dict; let item = page .values .next() - .map(|index| dict_values.value(index.unwrap() as usize)) + .map(|index| dict_values.value(index as usize)) .unwrap_or_default(); + page.values.get_result()?; values.push(item); validity.push(true); }, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs index 13c01d9bca624..11a16351ea458 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binary/utils.rs @@ -80,6 +80,11 @@ impl<'a, O: Offset> Pushable<&'a [u8]> for Binary { assert_eq!(value.len(), 0); self.extend_constant(additional) } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } } #[derive(Debug)] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs index 8f4806f400563..1b3d657992932 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/basic.rs @@ -5,6 +5,7 @@ use arrow::array::{Array, ArrayRef, BinaryViewArray, MutableBinaryViewArray, Utf use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::{ArrowDataType, PhysicalType}; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::binary::decoders::*; use crate::parquet::page::{DataPage, DictPage}; @@ -57,6 +58,8 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { additional: usize, ) -> PolarsResult<()> { let (values, validity) = decoded; + let views_offset = values.views().len(); + let buffer_offset = values.completed_buffers().len(); let mut validate_utf8 = self.check_utf8.take(); match state { @@ -108,8 +111,9 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { &mut page_values .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)), - ) + .map(|index| page_dict.value(index as usize)), + ); + page_values.values.get_result()?; }, BinaryState::RequiredDictionary(page) => { // Already done on the dict. @@ -119,11 +123,12 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { for x in page .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)) + .map(|index| page_dict.value(index as usize)) .take(additional) { values.push_value_ignore_validity(x) } + page.values.get_result()?; }, BinaryState::FilteredOptional(page_validity, page_values) => { extend_from_decoder( @@ -152,11 +157,12 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { for x in page .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)) + .map(|index| page_dict.value(index as usize)) .take(additional) { values.push_value_ignore_validity(x) } + page.values.iter.get_result()?; }, BinaryState::FilteredOptionalDictionary(page_validity, page_values) => { // Already done on the dict. @@ -172,8 +178,9 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { &mut page_values .values .by_ref() - .map(|index| page_dict.value(index.unwrap() as usize)), - ) + .map(|index| page_dict.value(index as usize)), + ); + page_values.values.get_result()?; }, BinaryState::OptionalDeltaByteArray(page_validity, page_values) => extend_from_decoder( validity, @@ -190,7 +197,7 @@ impl<'a> utils::Decoder<'a> for BinViewDecoder { } if validate_utf8 { - values.validate_utf8() + values.validate_utf8(buffer_offset, views_offset) } else { Ok(()) } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/binview/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/binview/nested.rs index 4195265550d17..2b6b66f5e4f4f 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/binview/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/binview/nested.rs @@ -4,6 +4,7 @@ use arrow::array::{ArrayRef, MutableBinaryViewArray}; use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use crate::parquet::page::{DataPage, DictPage}; use crate::read::deserialize::binary::decoders::{ @@ -60,19 +61,21 @@ impl<'a> NestedDecoder<'a> for BinViewDecoder { let item = page .values .next() - .map(|index| dict_values.value(index.unwrap() as usize)) + .map(|index| dict_values.value(index as usize)) .unwrap_or_default(); values.push_value_ignore_validity(item); + page.values.get_result()?; }, BinaryNestedState::OptionalDictionary(page) => { let dict_values = &page.dict; let item = page .values .next() - .map(|index| dict_values.value(index.unwrap() as usize)) + .map(|index| dict_values.value(index as usize)) .unwrap_or_default(); values.push_value_ignore_validity(item); validity.push(true); + page.values.get_result()?; }, } Ok(()) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs index 5f2d1107cd49b..aab7a6e91be91 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/boolean/basic.rs @@ -5,6 +5,7 @@ use arrow::bitmap::utils::BitmapIter; use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::utils::{ extend_from_decoder, get_selected_rows, next, DecodedState, Decoder, @@ -144,7 +145,6 @@ impl<'a> Decoder<'a> for BooleanDecoder { let iter = hybrid_rle::Decoder::new(values, 1); let values = HybridDecoderBitmapIter::new(iter, page.num_values()); let values = HybridRleBooleanIter::new(values); - Ok(State::RleOptional(optional, values)) }, _ => Err(utils::not_implemented(page)), @@ -199,8 +199,9 @@ impl<'a> Decoder<'a> for BooleanDecoder { page_validity, Some(remaining), values, - page_values.map(|v| v.unwrap()), + &mut *page_values, ); + page_values.get_result()?; }, } Ok(()) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary/mod.rs index 2facc235eb7a0..f9401771706ca 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary/mod.rs @@ -154,29 +154,28 @@ where ) -> PolarsResult<()> { let (values, validity) = decoded; match state { - State::Optional(page) => extend_from_decoder( - validity, - &mut page.validity, - Some(remaining), - values, - &mut page.values.by_ref().map(|x| { - // todo: rm unwrap - let x: usize = x.unwrap().try_into().unwrap(); - match x.try_into() { - Ok(key) => key, - // todo: convert this to an error. - Err(_) => panic!("The maximum key is too small"), - } - }), - ), + State::Optional(page) => { + extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + &mut page.values.by_ref().map(|x| { + match (x as usize).try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => panic!("The maximum key is too small"), + } + }), + ); + page.values.get_result()?; + }, State::Required(page) => { values.extend( page.values .by_ref() .map(|x| { - // todo: rm unwrap - let x: usize = x.unwrap().try_into().unwrap(); - let x: K = match x.try_into() { + let x: K = match (x as usize).try_into() { Ok(key) => key, // todo: convert this to an error. Err(_) => { @@ -187,33 +186,33 @@ where }) .take(remaining), ); + page.values.get_result()?; + }, + State::FilteredOptional(page_validity, page_values) => { + extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.by_ref().map(|x| { + let x: K = match (x as usize).try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }), + ); + page_values.get_result()?; }, - State::FilteredOptional(page_validity, page_values) => extend_from_decoder( - validity, - page_validity, - Some(remaining), - values, - &mut page_values.by_ref().map(|x| { - // todo: rm unwrap - let x: usize = x.unwrap().try_into().unwrap(); - let x: K = match x.try_into() { - Ok(key) => key, - // todo: convert this to an error. - Err(_) => { - panic!("The maximum key is too small") - }, - }; - x - }), - ), State::FilteredRequired(page) => { values.extend( page.values .by_ref() .map(|x| { - // todo: rm unwrap - let x: usize = x.unwrap().try_into().unwrap(); - let x: K = match x.try_into() { + let x: K = match (x as usize).try_into() { Ok(key) => key, // todo: convert this to an error. Err(_) => { @@ -224,6 +223,7 @@ where }) .take(remaining), ); + page.values.iter.get_result()?; }, } Ok(()) @@ -319,3 +319,4 @@ pub(super) fn next_dict Box< pub use nested::next_dict as nested_next_dict; use polars_error::{polars_err, PolarsResult}; +use polars_utils::iter::FallibleIterator; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/dictionary/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/dictionary/nested.rs index 885f542be98ea..8e75eebaff10f 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/dictionary/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/dictionary/nested.rs @@ -4,6 +4,7 @@ use arrow::array::{Array, DictionaryArray, DictionaryKey}; use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use polars_error::{polars_err, PolarsResult}; +use polars_utils::iter::FallibleIterator; use super::super::super::PagesIter; use super::super::nested_utils::*; @@ -112,20 +113,18 @@ impl<'a, K: DictionaryKey> NestedDecoder<'a> for DictionaryDecoder { let (values, validity) = decoded; match state { State::Optional(page_values) => { - let key = page_values.next().transpose()?; - // todo: convert unwrap to error - let key = match K::try_from(key.unwrap_or_default() as usize) { - Ok(key) => key, - Err(_) => todo!(), + let key = page_values.next().unwrap_or_default(); + let Ok(key) = K::try_from(key as usize) else { + panic! {} }; values.push(key); validity.push(true); + page_values.get_result()?; }, State::Required(page_values) => { - let key = page_values.values.next().transpose()?; - let key = match K::try_from(key.unwrap_or_default() as usize) { - Ok(key) => key, - Err(_) => todo!(), + let key = page_values.values.next().unwrap_or_default(); + let Ok(key) = K::try_from(key as usize) else { + panic! {} }; values.push(key); }, diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/basic.rs index 4cb8d146f3b30..7490d4e92c4c4 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/basic.rs @@ -5,6 +5,7 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::pushable::Pushable; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::utils::{ dict_indices_decoder, extend_from_decoder, get_selected_rows, next, not_implemented, @@ -229,28 +230,32 @@ impl<'a> Decoder<'a> for BinaryDecoder { values.push(x) } }, - State::OptionalDictionary(page) => extend_from_decoder( - validity, - &mut page.validity, - Some(remaining), - values, - page.values.by_ref().map(|index| { - let index = index.unwrap() as usize; - &page.dict[index * self.size..(index + 1) * self.size] - }), - ), + State::OptionalDictionary(page) => { + extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + page.values.by_ref().map(|index| { + let index = index as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }), + ); + page.values.get_result()?; + }, State::RequiredDictionary(page) => { for x in page .values .by_ref() .map(|index| { - let index = index.unwrap() as usize; + let index = index as usize; &page.dict[index * self.size..(index + 1) * self.size] }) .take(remaining) { values.push(x) } + page.values.get_result()?; }, State::FilteredOptional(page_validity, page_values) => { extend_from_decoder( diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/nested.rs index 2320abebfb9e8..307862ba3a53e 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/nested.rs @@ -5,6 +5,7 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::pushable::Pushable; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::utils::{not_implemented, MaybeNext, PageState}; use super::utils::FixedSizeBinary; @@ -101,11 +102,12 @@ impl<'a> NestedDecoder<'a> for BinaryDecoder { .by_ref() .next() .map(|index| { - let index = index.unwrap() as usize; + let index = index as usize; &page.dict[index * self.size..(index + 1) * self.size] }) .unwrap_or_default(); values.push(item); + page.values.get_result()?; }, State::OptionalDictionary(page) => { let item = page @@ -113,12 +115,13 @@ impl<'a> NestedDecoder<'a> for BinaryDecoder { .by_ref() .next() .map(|index| { - let index = index.unwrap() as usize; + let index = index as usize; &page.dict[index * self.size..(index + 1) * self.size] }) .unwrap_or_default(); values.push(item); validity.push(true); + page.values.get_result()?; }, } Ok(()) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs index 219b0a51ba2e0..50a89442570a4 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/fixed_size_binary/utils.rs @@ -55,4 +55,9 @@ impl<'a> Pushable<&'a [u8]> for FixedSizeBinary { fn len(&self) -> usize { self.values.len() / self.size } + + #[inline] + fn extend_null_constant(&mut self, additional: usize) { + self.extend_constant(additional) + } } diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs index b62a016ad576e..25507816b8fef 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested.rs @@ -1,12 +1,9 @@ use arrow::array::PrimitiveArray; -use arrow::datatypes::{ArrowDataType, Field}; use arrow::match_integer_type; use ethnum::I256; use polars_error::polars_bail; -use super::nested_utils::{InitNested, NestedArrayIter}; use super::*; -use crate::parquet::schema::types::PrimitiveType; /// Converts an iterator of arrays to a trait object returning trait objects #[inline] @@ -407,7 +404,7 @@ where }, PhysicalType::FixedLenByteArray(n) => { polars_bail!(ComputeError: - "Can't decode Decimal256 type from from `FixedLenByteArray` of len {n}" + "Can't decode Decimal256 type from `FixedLenByteArray` of len {n}" ) }, _ => { diff --git a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs index c292072e7b75c..6590085283ee2 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/nested_utils.rs @@ -2,7 +2,8 @@ use std::collections::VecDeque; use arrow::array::Array; use arrow::bitmap::MutableBitmap; -use polars_error::PolarsResult; +use polars_error::{polars_bail, PolarsResult}; +use polars_utils::slice::GetSaferUnchecked; use super::super::PagesIter; use super::utils::{DecodedState, MaybeNext, PageState}; @@ -379,6 +380,9 @@ pub(super) fn extend<'a, D: NestedDecoder<'a>>( let chunk_size = chunk_size.unwrap_or(usize::MAX); let mut first_item_is_fully_read = false; + // Amortize the allocations. + let mut cum_sum = vec![]; + let mut cum_rep = vec![]; loop { if let Some((mut nested, mut decoded)) = items.pop_back() { @@ -392,6 +396,8 @@ pub(super) fn extend<'a, D: NestedDecoder<'a>>( &mut decoded, decoder, additional, + &mut cum_sum, + &mut cum_rep, )?; first_item_is_fully_read |= is_fully_read; *remaining -= nested.len() - existing; @@ -418,6 +424,7 @@ pub(super) fn extend<'a, D: NestedDecoder<'a>>( Ok(first_item_is_fully_read) } +#[allow(clippy::too_many_arguments)] fn extend_offsets2<'a, D: NestedDecoder<'a>>( page: &mut NestedPage<'a>, values_state: &mut D::State, @@ -425,19 +432,26 @@ fn extend_offsets2<'a, D: NestedDecoder<'a>>( decoded: &mut D::DecodedState, decoder: &D, additional: usize, + // Amortized allocations + cum_sum: &mut Vec, + cum_rep: &mut Vec, ) -> PolarsResult { let max_depth = nested.len(); - let mut cum_sum = vec![0u32; max_depth + 1]; + cum_sum.resize(max_depth + 1, 0); + cum_rep.resize(max_depth + 1, 0); for (i, nest) in nested.iter().enumerate() { let delta = nest.is_nullable() as u32 + nest.is_repeated() as u32; - cum_sum[i + 1] = cum_sum[i] + delta; + unsafe { + *cum_sum.get_unchecked_release_mut(i + 1) = *cum_sum.get_unchecked_release(i) + delta; + } } - let mut cum_rep = vec![0u32; max_depth + 1]; for (i, nest) in nested.iter().enumerate() { let delta = nest.is_repeated() as u32; - cum_rep[i + 1] = cum_rep[i] + delta; + unsafe { + *cum_rep.get_unchecked_release_mut(i + 1) = *cum_rep.get_unchecked_release(i) + delta; + } } let mut rows = 0; @@ -447,51 +461,58 @@ fn extend_offsets2<'a, D: NestedDecoder<'a>>( // yield batches of pages. This means e.g. it could be that the very // first page is a new row, and the existing nested state has already // contains all data from the additional rows. - if page.iter.peek().unwrap().0.as_ref().copied().unwrap() == 0 { + if page.iter.peek().unwrap().0 == 0 { if rows == additional { return Ok(true); } rows += 1; } - let (rep, def) = page.iter.next().unwrap(); - let rep = rep?; - let def = def?; + // The errors of the FallibleIterators use in this zipped not checked yet. + // If one of them errors, the iterator returns None, and this `unwrap` will panic. + let Some((rep, def)) = page.iter.next() else { + polars_bail!(ComputeError: "cannot read rep/def levels") + }; let mut is_required = false; - for depth in 0..max_depth { - let right_level = rep <= cum_rep[depth] && def >= cum_sum[depth]; - if is_required || right_level { - let length = nested - .get(depth + 1) - .map(|x| x.len() as i64) - // the last depth is the leaf, which is always increased by 1 - .unwrap_or(1); - - let nest = &mut nested[depth]; - - let is_valid = nest.is_nullable() && def > cum_sum[depth]; - nest.push(length, is_valid); - is_required = nest.is_required() && !is_valid; - - if depth == max_depth - 1 { - // the leaf / primitive - let is_valid = (def != cum_sum[depth]) || !nest.is_nullable(); - if right_level && is_valid { - decoder.push_valid(values_state, decoded)?; - } else { - decoder.push_null(decoded); + + // SAFETY: only bound check elision. + unsafe { + for depth in 0..max_depth { + let right_level = rep <= *cum_rep.get_unchecked_release(depth) + && def >= *cum_sum.get_unchecked_release(depth); + if is_required || right_level { + let length = nested + .get(depth + 1) + .map(|x| x.len() as i64) + // the last depth is the leaf, which is always increased by 1 + .unwrap_or(1); + + let nest = nested.get_unchecked_release_mut(depth); + + let is_valid = + nest.is_nullable() && def > *cum_sum.get_unchecked_release(depth); + nest.push(length, is_valid); + is_required = nest.is_required() && !is_valid; + + if depth == max_depth - 1 { + // the leaf / primitive + let is_valid = + (def != *cum_sum.get_unchecked_release(depth)) || !nest.is_nullable(); + if right_level && is_valid { + decoder.push_valid(values_state, decoded)?; + } else { + decoder.push_null(decoded); + } } } } } if page.iter.len() == 0 { - break; + return Ok(false); } } - - Ok(false) } #[inline] diff --git a/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs b/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs index 6e02a57e65fb8..da512675a4d42 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/null/mod.rs @@ -61,6 +61,7 @@ mod tests { use super::iter_to_arrays; use crate::parquet::encoding::Encoding; use crate::parquet::error::Error as ParquetError; + #[allow(unused_imports)] use crate::parquet::fallible_streaming_iterator; use crate::parquet::metadata::Descriptor; use crate::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs index 32c3a221165a2..c8186466f5476 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/basic.rs @@ -5,6 +5,7 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::utils::{ get_selected_rows, FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, @@ -232,18 +233,14 @@ where page_validity, Some(remaining), values, - &mut page_values.values.by_ref().map(|x| x.unwrap()).map(op1), - ) + &mut page_values.values.by_ref().map(op1), + ); + page_values.values.get_result()?; }, State::RequiredDictionary(page) => { let op1 = |index: u32| page.dict[index as usize]; - values.extend( - page.values - .by_ref() - .map(|x| x.unwrap()) - .map(op1) - .take(remaining), - ); + values.extend(page.values.by_ref().map(op1).take(remaining)); + page.values.get_result()?; }, State::FilteredRequired(page) => { values.extend( diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs index af53485ff93fe..1a1759252ab89 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/dictionary.rs @@ -6,7 +6,7 @@ use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; use polars_error::PolarsResult; -use super::super::dictionary::{nested_next_dict, *}; +use super::super::dictionary::*; use super::super::nested_utils::{InitNested, NestedState}; use super::super::utils::MaybeNext; use super::super::PagesIter; diff --git a/crates/polars-parquet/src/arrow/read/deserialize/primitive/nested.rs b/crates/polars-parquet/src/arrow/read/deserialize/primitive/nested.rs index 978ddc76233dd..8ab22737153c0 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/primitive/nested.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/primitive/nested.rs @@ -5,6 +5,7 @@ use arrow::bitmap::MutableBitmap; use arrow::datatypes::ArrowDataType; use arrow::types::NativeType; use polars_error::PolarsResult; +use polars_utils::iter::FallibleIterator; use super::super::nested_utils::*; use super::super::utils::MaybeNext; @@ -129,21 +130,17 @@ where values.push(value.unwrap_or_default()); }, State::RequiredDictionary(page) => { - let value = page - .values - .next() - .map(|index| page.dict[index.unwrap() as usize]); + let value = page.values.next().map(|index| page.dict[index as usize]); values.push(value.unwrap_or_default()); + page.values.get_result()?; }, State::OptionalDictionary(page) => { - let value = page - .values - .next() - .map(|index| page.dict[index.unwrap() as usize]); + let value = page.values.next().map(|index| page.dict[index as usize]); values.push(value.unwrap_or_default()); validity.push(true); + page.values.get_result()?; }, } Ok(()) diff --git a/crates/polars-parquet/src/arrow/read/deserialize/utils.rs b/crates/polars-parquet/src/arrow/read/deserialize/utils.rs index f41f98b46bf16..6919dd88dd746 100644 --- a/crates/polars-parquet/src/arrow/read/deserialize/utils.rs +++ b/crates/polars-parquet/src/arrow/read/deserialize/utils.rs @@ -227,7 +227,7 @@ impl<'a> PageValidity<'a> for OptionalPageValidity<'a> { } } -fn reserve_pushable_and_validity<'a, T: Default, P: Pushable>( +fn reserve_pushable_and_validity<'a, T, P: Pushable>( validity: &mut MutableBitmap, page_validity: &'a mut dyn PageValidity, limit: Option, @@ -263,7 +263,7 @@ fn reserve_pushable_and_validity<'a, T: Default, P: Pushable>( } /// Extends a [`Pushable`] from an iterator of non-null values and an hybrid-rle decoder -pub(super) fn extend_from_decoder, I: Iterator>( +pub(super) fn extend_from_decoder, I: Iterator>( validity: &mut MutableBitmap, page_validity: &mut dyn PageValidity, limit: Option, @@ -300,7 +300,7 @@ pub(super) fn extend_from_decoder, I: Iterator for _ in values_iter.by_ref().take(valids) {}, diff --git a/crates/polars-parquet/src/arrow/read/indexes/mod.rs b/crates/polars-parquet/src/arrow/read/indexes/mod.rs index 0a8184bc27238..a9a48a98e8f49 100644 --- a/crates/polars-parquet/src/arrow/read/indexes/mod.rs +++ b/crates/polars-parquet/src/arrow/read/indexes/mod.rs @@ -184,7 +184,9 @@ fn deserialize( PhysicalType::Binary | PhysicalType::LargeBinary | PhysicalType::Utf8 - | PhysicalType::LargeUtf8 => { + | PhysicalType::LargeUtf8 + | PhysicalType::Utf8View + | PhysicalType::BinaryView => { let index = indexes .pop_front() .unwrap() diff --git a/crates/polars-parquet/src/arrow/read/schema/convert.rs b/crates/polars-parquet/src/arrow/read/schema/convert.rs index 5eeaa94a1355e..2089e261188fa 100644 --- a/crates/polars-parquet/src/arrow/read/schema/convert.rs +++ b/crates/polars-parquet/src/arrow/read/schema/convert.rs @@ -150,15 +150,15 @@ fn from_byte_array( converted_type: &Option, ) -> ArrowDataType { match (logical_type, converted_type) { - (Some(PrimitiveLogicalType::String), _) => ArrowDataType::LargeUtf8, - (Some(PrimitiveLogicalType::Json), _) => ArrowDataType::LargeBinary, - (Some(PrimitiveLogicalType::Bson), _) => ArrowDataType::LargeBinary, - (Some(PrimitiveLogicalType::Enum), _) => ArrowDataType::LargeBinary, - (_, Some(PrimitiveConvertedType::Json)) => ArrowDataType::LargeBinary, - (_, Some(PrimitiveConvertedType::Bson)) => ArrowDataType::LargeBinary, - (_, Some(PrimitiveConvertedType::Enum)) => ArrowDataType::LargeBinary, - (_, Some(PrimitiveConvertedType::Utf8)) => ArrowDataType::LargeUtf8, - (_, _) => ArrowDataType::LargeBinary, + (Some(PrimitiveLogicalType::String), _) => ArrowDataType::Utf8View, + (Some(PrimitiveLogicalType::Json), _) => ArrowDataType::BinaryView, + (Some(PrimitiveLogicalType::Bson), _) => ArrowDataType::BinaryView, + (Some(PrimitiveLogicalType::Enum), _) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Json)) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Bson)) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Enum)) => ArrowDataType::BinaryView, + (_, Some(PrimitiveConvertedType::Utf8)) => ArrowDataType::Utf8View, + (_, _) => ArrowDataType::BinaryView, } } @@ -407,7 +407,6 @@ pub(crate) fn to_data_type( #[cfg(test)] mod tests { - use arrow::datatypes::{ArrowDataType, Field, TimeUnit}; use polars_error::*; use super::*; @@ -440,8 +439,8 @@ mod tests { Field::new("int64", ArrowDataType::Int64, false), Field::new("double", ArrowDataType::Float64, true), Field::new("float", ArrowDataType::Float32, true), - Field::new("string", ArrowDataType::LargeUtf8, true), - Field::new("string_2", ArrowDataType::LargeUtf8, true), + Field::new("string", ArrowDataType::Utf8View, true), + Field::new("string_2", ArrowDataType::Utf8View, true), ]; let parquet_schema = SchemaDescriptor::try_from_message(message)?; @@ -460,7 +459,7 @@ mod tests { } "; let expected = vec![ - Field::new("binary", ArrowDataType::LargeBinary, false), + Field::new("binary", ArrowDataType::BinaryView, false), Field::new("fixed_binary", ArrowDataType::FixedSizeBinary(20), false), ]; @@ -733,7 +732,7 @@ mod tests { { let struct_fields = vec![ - Field::new("event_name", ArrowDataType::LargeUtf8, false), + Field::new("event_name", ArrowDataType::Utf8View, false), Field::new( "event_time", ArrowDataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), @@ -793,7 +792,7 @@ mod tests { "my_list1", ArrowDataType::LargeList(Box::new(Field::new( "element", - ArrowDataType::LargeUtf8, + ArrowDataType::Utf8View, true, ))), false, @@ -811,7 +810,7 @@ mod tests { "my_list2", ArrowDataType::LargeList(Box::new(Field::new( "element", - ArrowDataType::LargeUtf8, + ArrowDataType::Utf8View, false, ))), true, @@ -829,7 +828,7 @@ mod tests { "my_list3", ArrowDataType::LargeList(Box::new(Field::new( "element", - ArrowDataType::LargeUtf8, + ArrowDataType::Utf8View, false, ))), false, @@ -1059,7 +1058,7 @@ mod tests { Field::new("int64", ArrowDataType::Int64, false), Field::new("double", ArrowDataType::Float64, true), Field::new("float", ArrowDataType::Float32, true), - Field::new("string", ArrowDataType::LargeUtf8, true), + Field::new("string", ArrowDataType::Utf8View, true), Field::new( "bools", ArrowDataType::LargeList(Box::new(Field::new( @@ -1116,7 +1115,7 @@ mod tests { ]), false, ), - Field::new("dictionary_strings", ArrowDataType::LargeUtf8, false), + Field::new("dictionary_strings", ArrowDataType::Utf8View, false), ]; let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; diff --git a/crates/polars-parquet/src/arrow/read/statistics/list.rs b/crates/polars-parquet/src/arrow/read/statistics/list.rs index 10f966be2f350..34908cab5ecc1 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/list.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/list.rs @@ -14,7 +14,9 @@ pub struct DynMutableListArray { impl DynMutableListArray { pub fn try_with_capacity(data_type: ArrowDataType, capacity: usize) -> PolarsResult { let inner = match data_type.to_logical_type() { - ArrowDataType::List(inner) | ArrowDataType::LargeList(inner) => inner.data_type(), + ArrowDataType::List(inner) + | ArrowDataType::LargeList(inner) + | ArrowDataType::FixedSizeList(inner, _) => inner.data_type(), _ => unreachable!(), }; let inner = make_mutable(inner, capacity)?; @@ -60,6 +62,11 @@ impl MutableArray for DynMutableListArray { None, )) }, + ArrowDataType::FixedSizeList(field, _) => Box::new(FixedSizeListArray::new( + ArrowDataType::FixedSizeList(field.clone(), inner.len()), + inner, + None, + )), _ => unreachable!(), } } diff --git a/crates/polars-parquet/src/arrow/read/statistics/mod.rs b/crates/polars-parquet/src/arrow/read/statistics/mod.rs index 465e428ba0eeb..32fc15d93e16d 100644 --- a/crates/polars-parquet/src/arrow/read/statistics/mod.rs +++ b/crates/polars-parquet/src/arrow/read/statistics/mod.rs @@ -62,90 +62,76 @@ struct MutableStatistics { impl From for Statistics { fn from(mut s: MutableStatistics) -> Self { - let null_count = if let PhysicalType::Struct = s.null_count.data_type().to_physical_type() { - s.null_count + let null_count = match s.null_count.data_type().to_physical_type() { + PhysicalType::Struct => s + .null_count .as_box() .as_any() .downcast_ref::() .unwrap() .clone() - .boxed() - } else if let PhysicalType::Map = s.null_count.data_type().to_physical_type() { - s.null_count - .as_box() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - .boxed() - } else if let PhysicalType::List = s.null_count.data_type().to_physical_type() { - s.null_count + .boxed(), + PhysicalType::List => s + .null_count .as_box() .as_any() .downcast_ref::>() .unwrap() .clone() - .boxed() - } else if let PhysicalType::LargeList = s.null_count.data_type().to_physical_type() { - s.null_count + .boxed(), + PhysicalType::LargeList => s + .null_count .as_box() .as_any() .downcast_ref::>() .unwrap() .clone() - .boxed() - } else { - s.null_count + .boxed(), + _ => s + .null_count .as_box() .as_any() .downcast_ref::() .unwrap() .clone() - .boxed() + .boxed(), }; - let distinct_count = if let PhysicalType::Struct = - s.distinct_count.data_type().to_physical_type() - { - s.distinct_count + + let distinct_count = match s.distinct_count.data_type().to_physical_type() { + PhysicalType::Struct => s + .distinct_count .as_box() .as_any() .downcast_ref::() .unwrap() .clone() - .boxed() - } else if let PhysicalType::Map = s.distinct_count.data_type().to_physical_type() { - s.distinct_count - .as_box() - .as_any() - .downcast_ref::() - .unwrap() - .clone() - .boxed() - } else if let PhysicalType::List = s.distinct_count.data_type().to_physical_type() { - s.distinct_count + .boxed(), + PhysicalType::List => s + .distinct_count .as_box() .as_any() .downcast_ref::>() .unwrap() .clone() - .boxed() - } else if let PhysicalType::LargeList = s.distinct_count.data_type().to_physical_type() { - s.distinct_count + .boxed(), + PhysicalType::LargeList => s + .distinct_count .as_box() .as_any() .downcast_ref::>() .unwrap() .clone() - .boxed() - } else { - s.distinct_count + .boxed(), + _ => s + .distinct_count .as_box() .as_any() .downcast_ref::() .unwrap() .clone() - .boxed() + .boxed(), }; + Self { null_count, distinct_count, @@ -180,9 +166,10 @@ fn make_mutable(data_type: &ArrowDataType, capacity: usize) -> PolarsResult Box::new( + PhysicalType::LargeList | PhysicalType::List | PhysicalType::FixedSizeList => Box::new( DynMutableListArray::try_with_capacity(data_type.clone(), capacity)?, - ) as Box, + ) + as Box, PhysicalType::Dictionary(_) => Box::new( dictionary::DynMutableDictionary::try_with_capacity(data_type.clone(), capacity)?, ), @@ -212,32 +199,27 @@ fn make_mutable(data_type: &ArrowDataType, capacity: usize) -> PolarsResult ArrowDataType { - if let ArrowDataType::Struct(fields) = data_type.to_logical_type() { - ArrowDataType::Struct( + match data_type.to_logical_type() { + ArrowDataType::Struct(fields) => ArrowDataType::Struct( fields .iter() .map(|f| Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)) .collect(), - ) - } else if let ArrowDataType::Map(f, ordered) = data_type.to_logical_type() { - ArrowDataType::Map( + ), + ArrowDataType::Map(f, ordered) => ArrowDataType::Map( Box::new(Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)), *ordered, - ) - } else if let ArrowDataType::List(f) = data_type.to_logical_type() { - ArrowDataType::List(Box::new(Field::new( - &f.name, - create_dt(&f.data_type), - f.is_nullable, - ))) - } else if let ArrowDataType::LargeList(f) = data_type.to_logical_type() { - ArrowDataType::LargeList(Box::new(Field::new( + ), + ArrowDataType::LargeList(f) => ArrowDataType::LargeList(Box::new(Field::new( &f.name, create_dt(&f.data_type), f.is_nullable, - ))) - } else { - ArrowDataType::UInt64 + ))), + // FixedSizeList piggy backs on list + ArrowDataType::List(f) | ArrowDataType::FixedSizeList(f, _) => ArrowDataType::List( + Box::new(Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)), + ), + _ => ArrowDataType::UInt64, } } @@ -330,7 +312,7 @@ fn push( null_count: &mut dyn MutableArray, ) -> PolarsResult<()> { match min.data_type().to_logical_type() { - List(_) | LargeList(_) => { + List(_) | LargeList(_) | FixedSizeList(_, _) => { let min = min .as_mut_any() .downcast_mut::() diff --git a/crates/polars-parquet/src/arrow/write/dictionary.rs b/crates/polars-parquet/src/arrow/write/dictionary.rs index cfc5ad888a840..b3ea666865c9c 100644 --- a/crates/polars-parquet/src/arrow/write/dictionary.rs +++ b/crates/polars-parquet/src/arrow/write/dictionary.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, DictionaryArray, DictionaryKey, Utf8ViewArray}; +use arrow::array::{Array, BinaryViewArray, DictionaryArray, DictionaryKey, Utf8ViewArray}; use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::datatypes::{ArrowDataType, IntegerType}; use num_traits::ToPrimitive; @@ -149,13 +149,11 @@ fn serialize_levels( fn normalized_validity(array: &DictionaryArray) -> Option { match (array.keys().validity(), array.values().validity()) { (None, None) => None, - (None, rhs) => rhs.cloned(), - (lhs, None) => lhs.cloned(), - (Some(_), Some(rhs)) => { - let projected_validity = array - .keys_iter() - .map(|x| x.map(|x| rhs.get_bit(x)).unwrap_or(false)); - MutableBitmap::from_trusted_len_iter(projected_validity).into() + (keys, None) => keys.cloned(), + // The values can have a different length than the keys + (_, Some(_values)) => { + let iter = (0..array.len()).map(|i| unsafe { !array.is_null_unchecked(i) }); + MutableBitmap::from_trusted_len_iter(iter).into() }, } } @@ -169,9 +167,6 @@ fn serialize_keys( ) -> PolarsResult { let mut buffer = vec![]; - // parquet only accepts a single validity - we "&" the validities into a single one - // and ignore keys whole _value_ is null. - let validity = normalized_validity(array); let (start, len) = slice_nested_leaf(nested); let mut nested = nested.to_vec(); @@ -181,6 +176,10 @@ fn serialize_keys( } else { unreachable!("") } + // Parquet only accepts a single validity - we "&" the validities into a single one + // and ignore keys whose _value_ is null. + // It's important that we slice before normalizing. + let validity = normalized_validity(&array); let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels( validity.as_ref(), @@ -242,7 +241,7 @@ pub fn array_to_pages( match encoding { Encoding::PlainDictionary | Encoding::RleDictionary => { // write DictPage - let (dict_page, statistics): (_, Option) = + let (dict_page, mut statistics): (_, Option) = match array.values().data_type().to_logical_type() { ArrowDataType::Int8 => dyn_prim!(i8, i32, array, options, type_), ArrowDataType::Int16 => dyn_prim!(i16, i32, array, options, type_), @@ -278,6 +277,22 @@ pub fn array_to_pages( }; (DictPage::new(buffer, array.len(), false), stats) }, + ArrowDataType::BinaryView => { + let array = array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let mut buffer = vec![]; + binview::encode_plain(array, &mut buffer); + + let stats = if options.write_statistics { + Some(binview::build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, ArrowDataType::Utf8View => { let array = array .values() @@ -301,9 +316,7 @@ pub fn array_to_pages( let mut buffer = vec![]; binary_encode_plain::(values, &mut buffer); let stats = if options.write_statistics { - let mut stats = binary_build_statistics(values, type_.clone()); - stats.null_count = Some(array.null_count() as i64); - Some(stats) + Some(binary_build_statistics(values, type_.clone())) } else { None }; @@ -314,8 +327,7 @@ pub fn array_to_pages( let array = array.values().as_any().downcast_ref().unwrap(); fixed_binary_encode_plain(array, false, &mut buffer); let stats = if options.write_statistics { - let mut stats = fixed_binary_build_statistics(array, type_.clone()); - stats.null_count = Some(array.null_count() as i64); + let stats = fixed_binary_build_statistics(array, type_.clone()); Some(serialize_statistics(&stats)) } else { None @@ -329,6 +341,10 @@ pub fn array_to_pages( }, }; + if let Some(stats) = &mut statistics { + stats.null_count = Some(array.null_count() as i64) + } + // write DataPage pointing to DictPage let data_page = serialize_keys(array, type_, nested, statistics, options)?.unwrap_data(); diff --git a/crates/polars-parquet/src/arrow/write/mod.rs b/crates/polars-parquet/src/arrow/write/mod.rs index e9cb74279162c..d992a3f08a8ec 100644 --- a/crates/polars-parquet/src/arrow/write/mod.rs +++ b/crates/polars-parquet/src/arrow/write/mod.rs @@ -92,8 +92,9 @@ pub fn slice_nested_leaf(nested: &[Nested]) -> (usize, usize) { let end = *l_nested.offsets.last(); return (start as usize, (end - start) as usize); }, + Nested::FixedSizeList { len, width, .. } => return (0, *len * *width), Nested::Primitive(_, _, len) => out = (0, *len), - _ => {}, + Nested::Struct(_, _, _) => {}, } } out @@ -135,6 +136,7 @@ pub fn slice_parquet_array( validity.slice(current_offset, current_length) }; + // Update the offset/ length so that the Primitive is sliced properly. current_length = l_nested.offsets.range() as usize; current_offset = *l_nested.offsets.first() as usize; }, @@ -144,6 +146,7 @@ pub fn slice_parquet_array( validity.slice(current_offset, current_length) }; + // Update the offset/ length so that the Primitive is sliced properly. current_length = l_nested.offsets.range() as usize; current_offset = *l_nested.offsets.first() as usize; }, @@ -160,6 +163,20 @@ pub fn slice_parquet_array( }; primitive_array.slice(current_offset, current_length); }, + Nested::FixedSizeList { + validity, + len, + width, + .. + } => { + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + *len = current_length; + // Update the offset/ length so that the Primitive is sliced properly. + current_length *= *width; + current_offset *= *width; + }, } } } @@ -171,6 +188,7 @@ pub fn get_max_length(nested: &[Nested]) -> usize { match nested { Nested::LargeList(l_nested) => length += l_nested.offsets.range() as usize, Nested::List(l_nested) => length += l_nested.offsets.range() as usize, + Nested::FixedSizeList { len, width, .. } => length += *len * *width, _ => {}, } } diff --git a/crates/polars-parquet/src/arrow/write/nested/def.rs b/crates/polars-parquet/src/arrow/write/nested/def.rs index c2497205d8625..4b4c21767cf42 100644 --- a/crates/polars-parquet/src/arrow/write/nested/def.rs +++ b/crates/polars-parquet/src/arrow/write/nested/def.rs @@ -1,5 +1,6 @@ use arrow::bitmap::Bitmap; use arrow::offset::Offset; +use polars_utils::slice::GetSaferUnchecked; use super::super::pages::{ListNested, Nested}; use super::rep::num_values; @@ -50,6 +51,34 @@ fn single_list_iter<'a, O: Offset>(nested: &'a ListNested) -> Box( + width: usize, + is_optional: bool, + validity: Option<&'a Bitmap>, + len: usize, +) -> Box { + let lengths = std::iter::repeat(width).take(len); + match (is_optional, validity) { + (false, _) => Box::new( + std::iter::repeat(0u32) + .zip(lengths) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + (true, None) => Box::new( + std::iter::repeat(1u32) + .zip(lengths) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + (true, Some(validity)) => Box::new( + validity + .iter() + .map(|x| (x as u32)) + .zip(lengths) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + } +} + fn iter<'a>(nested: &'a [Nested]) -> Vec> { nested .iter() @@ -62,6 +91,13 @@ fn iter<'a>(nested: &'a [Nested]) -> Vec> { Nested::Struct(validity, is_optional, length) => { single_iter(validity, *is_optional, *length) }, + Nested::FixedSizeList { + validity, + is_optional, + len, + width, + .. + } => single_fixed_list_iter(*width, *is_optional, validity.as_ref(), *len), }) .collect() } @@ -150,15 +186,17 @@ impl<'a> Iterator for DefLevelsIter<'a> { let r = Some(self.total + empty_contrib); for index in (1..self.current_level).rev() { - if self.remaining[index] == 0 { - self.current_level -= 1; - self.remaining[index - 1] -= 1; - self.total -= self.validity[index]; + unsafe { + if *self.remaining.get_unchecked_release(index) == 0 { + self.current_level -= 1; + *self.remaining.get_unchecked_release_mut(index - 1) -= 1; + self.total -= *self.validity.get_unchecked_release(index); + } } } if self.remaining[0] == 0 { self.current_level = self.current_level.saturating_sub(1); - self.total -= self.validity[0]; + self.total -= unsafe { self.validity.get_unchecked_release(0) }; } self.remaining_values -= 1; r diff --git a/crates/polars-parquet/src/arrow/write/nested/mod.rs b/crates/polars-parquet/src/arrow/write/nested/mod.rs index c53d266255c53..46e15eec6c729 100644 --- a/crates/polars-parquet/src/arrow/write/nested/mod.rs +++ b/crates/polars-parquet/src/arrow/write/nested/mod.rs @@ -80,6 +80,7 @@ fn max_def_level(nested: &[Nested]) -> usize { Nested::List(nested) => 1 + (nested.is_optional as usize), Nested::LargeList(nested) => 1 + (nested.is_optional as usize), Nested::Struct(_, is_optional, _) => *is_optional as usize, + Nested::FixedSizeList { is_optional, .. } => *is_optional as usize, }) .sum() } @@ -88,7 +89,7 @@ fn max_rep_level(nested: &[Nested]) -> usize { nested .iter() .map(|nested| match nested { - Nested::LargeList(_) | Nested::List(_) => 1, + Nested::FixedSizeList { .. } | Nested::LargeList(_) | Nested::List(_) => 1, Nested::Primitive(_, _, _) | Nested::Struct(_, _, _) => 0, }) .sum() diff --git a/crates/polars-parquet/src/arrow/write/nested/rep.rs b/crates/polars-parquet/src/arrow/write/nested/rep.rs index 2bfbe1ce24f4c..52d73ded7b512 100644 --- a/crates/polars-parquet/src/arrow/write/nested/rep.rs +++ b/crates/polars-parquet/src/arrow/write/nested/rep.rs @@ -1,3 +1,5 @@ +use polars_utils::slice::GetSaferUnchecked; + use super::super::pages::Nested; use super::to_length; @@ -16,6 +18,9 @@ fn iter<'a>(nested: &'a [Nested]) -> Vec> { Nested::LargeList(nested) => { Some(Box::new(to_length(&nested.offsets)) as Box) }, + Nested::FixedSizeList { width, len, .. } => { + Some(Box::new(std::iter::repeat(*width).take(*len)) as Box) + }, Nested::Struct(_, _, _) => None, }) .collect() @@ -25,7 +30,7 @@ fn iter<'a>(nested: &'a [Nested]) -> Vec> { pub fn num_values(nested: &[Nested]) -> usize { let pr = match nested.last().unwrap() { Nested::Primitive(_, _, len) => *len, - _ => todo!(), + _ => unreachable!(), }; iter(nested) @@ -113,14 +118,16 @@ impl<'a> Iterator for RepLevelsIter<'a> { let r = Some((self.current_level - self.total) as u32); // update - for index in (1..self.current_level).rev() { - if self.remaining[index] == 0 { - self.current_level -= 1; - self.remaining[index - 1] -= 1; + unsafe { + for index in (1..self.current_level).rev() { + if *self.remaining.get_unchecked_release(index) == 0 { + self.current_level -= 1; + *self.remaining.get_unchecked_release_mut(index - 1) -= 1; + } + } + if *self.remaining.get_unchecked_release(0) == 0 { + self.current_level = self.current_level.saturating_sub(1); } - } - if self.remaining[0] == 0 { - self.current_level = self.current_level.saturating_sub(1); } self.total = 0; self.remaining_values -= 1; diff --git a/crates/polars-parquet/src/arrow/write/pages.rs b/crates/polars-parquet/src/arrow/write/pages.rs index 95b6e91d2b3d9..f62735258205f 100644 --- a/crates/polars-parquet/src/arrow/write/pages.rs +++ b/crates/polars-parquet/src/arrow/write/pages.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use arrow::array::{Array, ListArray, MapArray, StructArray}; +use arrow::array::{Array, FixedSizeListArray, ListArray, MapArray, StructArray}; use arrow::bitmap::Bitmap; use arrow::datatypes::PhysicalType; use arrow::offset::{Offset, OffsetsBuffer}; @@ -41,6 +41,13 @@ pub enum Nested { List(ListNested), /// a list LargeList(ListNested), + /// Width + FixedSizeList { + validity: Option, + is_optional: bool, + width: usize, + len: usize, + }, /// a struct /// - validity /// - is_optional @@ -56,6 +63,7 @@ impl Nested { Nested::List(nested) => nested.offsets.len_proxy(), Nested::LargeList(nested) => nested.offsets.len_proxy(), Nested::Struct(_, _, len) => *len, + Nested::FixedSizeList { len, .. } => *len, } } } @@ -98,6 +106,30 @@ fn to_nested_recursive( to_nested_recursive(array.as_ref(), type_, nested, parents.clone())?; } }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array".to_string(), + ) + } + } else { + polars_bail!(InvalidOperation: + "Parquet type must be a group for a list array".to_string(), + ) + }; + + parents.push(Nested::FixedSizeList { + validity: array.validity().cloned(), + len: array.len(), + width: array.size(), + is_optional, + }); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, List => { let array = array.as_any().downcast_ref::>().unwrap(); let type_ = if let ParquetType::GroupType { fields, .. } = type_ { @@ -204,6 +236,10 @@ fn to_leaves_recursive<'a>(array: &'a dyn Array, leaves: &mut Vec<&'a dyn Array> let array = array.as_any().downcast_ref::>().unwrap(); to_leaves_recursive(array.values().as_ref(), leaves); }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + to_leaves_recursive(array.values().as_ref(), leaves); + }, Map => { let array = array.as_any().downcast_ref::().unwrap(); to_leaves_recursive(array.field().as_ref(), leaves); @@ -262,10 +298,9 @@ pub fn array_to_columns + Send + Sync>( #[cfg(test)] mod tests { use arrow::array::*; - use arrow::bitmap::Bitmap; use arrow::datatypes::*; - use super::super::{FieldInfo, ParquetPhysicalType, ParquetPrimitiveType}; + use super::super::{FieldInfo, ParquetPhysicalType}; use super::*; use crate::parquet::schema::types::{ GroupLogicalType, PrimitiveConvertedType, PrimitiveLogicalType, diff --git a/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs b/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs index 576f4d5f1aba4..e8672648cda46 100644 --- a/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs +++ b/crates/polars-parquet/src/parquet/bloom_filter/split_block.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - /// magic numbers taken from https://github.com/apache/parquet-format/blob/master/BloomFilter.md const SALT: [u32; 8] = [ 1203114875, 1150766481, 2284105051, 2729912477, 1884591559, 770785867, 2667333959, 1550580529, diff --git a/crates/polars-parquet/src/parquet/compression.rs b/crates/polars-parquet/src/parquet/compression.rs index c33ea01b5fdb2..3e638eeb05c70 100644 --- a/crates/polars-parquet/src/parquet/compression.rs +++ b/crates/polars-parquet/src/parquet/compression.rs @@ -340,13 +340,6 @@ mod tests { ))); } - #[test] - fn test_codec_gzip_high_compression() { - test_codec(CompressionOptions::Gzip(Some( - GzipLevel::try_new(10).unwrap(), - ))); - } - #[test] fn test_codec_brotli_default() { test_codec(CompressionOptions::Brotli(None)); diff --git a/crates/polars-parquet/src/parquet/deserialize/filtered_rle.rs b/crates/polars-parquet/src/parquet/deserialize/filtered_rle.rs index 57c95c77b401c..7549b7de37384 100644 --- a/crates/polars-parquet/src/parquet/deserialize/filtered_rle.rs +++ b/crates/polars-parquet/src/parquet/deserialize/filtered_rle.rs @@ -88,7 +88,7 @@ impl<'a, I: Iterator, Error>>> FilteredHybridBit } /// Returns the number of elements remaining. Note that each run - /// of the iterator contains more than one element - this is is _not_ equivalent to size_hint. + /// of the iterator contains more than one element - this is _not_ equivalent to size_hint. pub fn len(&self) -> usize { self.total_items } diff --git a/crates/polars-parquet/src/parquet/deserialize/hybrid_rle.rs b/crates/polars-parquet/src/parquet/deserialize/hybrid_rle.rs index 746dd27b330d2..4ceab84b850ee 100644 --- a/crates/polars-parquet/src/parquet/deserialize/hybrid_rle.rs +++ b/crates/polars-parquet/src/parquet/deserialize/hybrid_rle.rs @@ -1,3 +1,5 @@ +use polars_utils::iter::FallibleIterator; + use crate::parquet::encoding::hybrid_rle::{self, BitmapIter}; use crate::parquet::error::Error; @@ -139,55 +141,70 @@ where { iter: I, current_run: Option>, + result: Result<(), Error>, } impl<'a, I> HybridRleBooleanIter<'a, I> where - I: Iterator, Error>>, + I: HybridRleRunsIterator<'a>, { pub fn new(iter: I) -> Self { Self { iter, current_run: None, + result: Ok(()), } } + + fn set_new_run(&mut self, run: Result, Error>) -> Option { + let run = match run { + Err(e) => { + self.result = Err(e); + return None; + }, + Ok(r) => r, + }; + + let run = match run { + HybridEncoded::Bitmap(bitmap, length) => { + HybridBooleanState::Bitmap(BitmapIter::new(bitmap, 0, length)) + }, + HybridEncoded::Repeated(value, length) => HybridBooleanState::Repeated(value, length), + }; + self.current_run = Some(run); + self.next() + } } impl<'a, I> Iterator for HybridRleBooleanIter<'a, I> where I: HybridRleRunsIterator<'a>, { - type Item = Result; + type Item = bool; #[inline] fn next(&mut self) -> Option { if let Some(run) = &mut self.current_run { match run { - HybridBooleanState::Bitmap(bitmap) => bitmap.next().map(Ok), - HybridBooleanState::Repeated(value, remaining) => if *remaining == 0 { - None - } else { - *remaining -= 1; - Some(*value) - } - .map(Ok), - } - } else if let Some(run) = self.iter.next() { - let run = run.map(|run| match run { - HybridEncoded::Bitmap(bitmap, length) => { - HybridBooleanState::Bitmap(BitmapIter::new(bitmap, 0, length)) - }, - HybridEncoded::Repeated(value, length) => { - HybridBooleanState::Repeated(value, length) + HybridBooleanState::Bitmap(bitmap) => match bitmap.next() { + Some(val) => Some(val), + None => { + let run = self.iter.next()?; + self.set_new_run(run) + }, }, - }); - match run { - Ok(run) => { - self.current_run = Some(run); - self.next() + HybridBooleanState::Repeated(value, remaining) => { + if *remaining == 0 { + let run = self.iter.next()?; + self.set_new_run(run) + } else { + *remaining -= 1; + Some(*value) + } }, - Err(e) => Some(Err(e)), } + } else if let Some(run) = self.iter.next() { + self.set_new_run(run) } else { None } @@ -200,5 +217,14 @@ where } } +impl<'a, I> FallibleIterator for HybridRleBooleanIter<'a, I> +where + I: HybridRleRunsIterator<'a>, +{ + fn get_result(&mut self) -> Result<(), Error> { + self.result.clone() + } +} + /// Type definition for a [`HybridRleBooleanIter`] using [`hybrid_rle::Decoder`]. pub type HybridRleDecoderIter<'a> = HybridRleBooleanIter<'a, HybridDecoderBitmapIter<'a>>; diff --git a/crates/polars-parquet/src/parquet/deserialize/utils.rs b/crates/polars-parquet/src/parquet/deserialize/utils.rs index 0c89d09d46484..da0c251244ba2 100644 --- a/crates/polars-parquet/src/parquet/deserialize/utils.rs +++ b/crates/polars-parquet/src/parquet/deserialize/utils.rs @@ -25,6 +25,7 @@ pub(super) fn dict_indices_decoder(page: &DataPage) -> Result { /// When the maximum definition level is 1, the definition levels are RLE-encoded and /// the bitpacked runs are bitmaps. This variant contains [`HybridDecoderBitmapIter`] @@ -57,27 +58,25 @@ impl<'a> DefLevelsDecoder<'a> { /// Iterator adapter to convert an iterator of non-null values and an iterator over validity /// into an iterator of optional values. #[derive(Debug, Clone)] -pub struct OptionalValues>, I: Iterator> { +pub struct OptionalValues, I: Iterator> { validity: V, values: I, } -impl>, I: Iterator> OptionalValues { +impl, I: Iterator> OptionalValues { pub fn new(validity: V, values: I) -> Self { Self { validity, values } } } -impl>, I: Iterator> Iterator - for OptionalValues -{ - type Item = Result, Error>; +impl, I: Iterator> Iterator for OptionalValues { + type Item = Option; #[inline] fn next(&mut self) -> Option { self.validity .next() - .map(|x| x.map(|x| if x { self.values.next() } else { None })) + .map(|x| if x { self.values.next() } else { None }) } #[inline] @@ -93,7 +92,7 @@ impl>, I: Iterator> Iterator /// allows this iterator to skip sequences of items without having to call each of them. #[derive(Debug, Clone)] pub struct SliceFilteredIter { - iter: I, + pub(crate) iter: I, selected_rows: VecDeque, current_remaining: usize, current: usize, // position in the slice @@ -146,8 +145,6 @@ impl> Iterator for SliceFilteredIter { #[cfg(test)] mod test { - use std::collections::VecDeque; - use super::*; #[test] diff --git a/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs b/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs index 904ff796dd341..cc4e62ebdd330 100644 --- a/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs +++ b/crates/polars-parquet/src/parquet/encoding/bitpacked/encode.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - use super::{Packed, Unpackable, Unpacked}; /// Encodes (packs) a slice of [`Unpackable`] into bitpacked bytes `packed`, using `num_bits` per value. diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs index d000b918efb73..3a867aa6b1bc2 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/decoder.rs @@ -29,22 +29,16 @@ impl<'a> Iterator for Decoder<'a> { #[inline] // -18% improvement in bench fn next(&mut self) -> Option { - if self.num_bits == 0 { - return None; - } - - if self.values.is_empty() { - return None; - } - let (indicator, consumed) = match uleb128::decode(self.values) { Ok((indicator, consumed)) => (indicator, consumed), Err(e) => return Some(Err(e)), }; self.values = unsafe { self.values.get_unchecked_release(consumed..) }; - if self.values.is_empty() { + + // We want to early return if consumed == 0 OR num_bits == 0, so combine into a single branch. + if (consumed * self.num_bits) == 0 { return None; - }; + } if indicator & 1 == 1 { // is bitpacking diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs index e2e381bc4b90f..1c4dd67ccec7c 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/encoder.rs @@ -40,7 +40,8 @@ fn bitpacked_encode_u32>( let remainder = length - chunks * U32_BLOCK_LEN; let mut buffer = [0u32; U32_BLOCK_LEN]; - let compressed_chunk_size = ceil8(U32_BLOCK_LEN * num_bits); + // simplified from ceil8(U32_BLOCK_LEN * num_bits) since U32_BLOCK_LEN = 32 + let compressed_chunk_size = 4 * num_bits; for _ in 0..chunks { iterator @@ -58,6 +59,9 @@ fn bitpacked_encode_u32>( // Must be careful here to ensure we write a multiple of `num_bits` // (the bit width) to align with the spec. Some readers also rely on // this - see https://github.com/pola-rs/polars/pull/13883. + + // this is ceil8(remainder * num_bits), but we ensure the output is a + // multiple of num_bits by rewriting it as ceil8(remainder) * num_bits let compressed_remainder_size = ceil8(remainder) * num_bits; iterator .by_ref() diff --git a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs index eed96ca431139..3dc0725525240 100644 --- a/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/hybrid_rle/mod.rs @@ -5,6 +5,7 @@ mod encoder; pub use bitmap::{encode_bool as bitpacked_encode, BitmapIter}; pub use decoder::Decoder; pub use encoder::{encode_bool, encode_u32}; +use polars_utils::iter::FallibleIterator; use super::bitpacked; use crate::parquet::error::Error; @@ -35,6 +36,7 @@ pub struct HybridRleDecoder<'a> { decoder: Decoder<'a>, state: State<'a>, remaining: usize, + result: Result<(), Error>, } #[inline] @@ -72,12 +74,13 @@ impl<'a> HybridRleDecoder<'a> { decoder, state, remaining: num_values, + result: Ok(()), }) } } impl<'a> Iterator for HybridRleDecoder<'a> { - type Item = Result; + type Item = u32; fn next(&mut self) -> Option { if self.remaining == 0 { @@ -96,12 +99,15 @@ impl<'a> Iterator for HybridRleDecoder<'a> { State::None => Some(0), } { self.remaining -= 1; - return Some(Ok(result)); + return Some(result); } self.state = match read_next(&mut self.decoder, self.remaining) { Ok(state) => state, - Err(e) => return Some(Err(e)), + Err(e) => { + self.result = Err(e); + return None; + }, } } } @@ -111,6 +117,13 @@ impl<'a> Iterator for HybridRleDecoder<'a> { } } +impl<'a> FallibleIterator for HybridRleDecoder<'a> { + #[inline] + fn get_result(&mut self) -> Result<(), Error> { + std::mem::replace(&mut self.result, Ok(())) + } +} + impl<'a> ExactSizeIterator for HybridRleDecoder<'a> {} #[cfg(test)] @@ -128,7 +141,7 @@ mod tests { let decoder = HybridRleDecoder::try_new(&buffer, num_bits, data.len())?; - let result = decoder.collect::, _>>()?; + let result = decoder.collect::>(); assert_eq!(result, data); Ok(()) @@ -212,7 +225,7 @@ mod tests { let decoder = HybridRleDecoder::try_new(&data, num_bits, 1000)?; - let result = decoder.collect::, _>>()?; + let result = decoder.collect::>(); assert_eq!(result, (0..1000).collect::>()); Ok(()) @@ -226,7 +239,7 @@ mod tests { let decoder = HybridRleDecoder::try_new(&data, num_bits, 1)?; - let result = decoder.collect::, _>>()?; + let result = decoder.collect::>(); assert_eq!(result, &[2]); Ok(()) @@ -240,7 +253,7 @@ mod tests { let decoder = HybridRleDecoder::try_new(&data, num_bits, 2)?; - let result = decoder.collect::, _>>()?; + let result = decoder.collect::>(); assert_eq!(result, &[0, 0]); Ok(()) @@ -254,7 +267,7 @@ mod tests { let decoder = HybridRleDecoder::try_new(&data, num_bits, 100)?; - let result = decoder.collect::, _>>()?; + let result = decoder.collect::>(); assert_eq!(result, vec![0; 100]); Ok(()) diff --git a/crates/polars-parquet/src/parquet/encoding/mod.rs b/crates/polars-parquet/src/parquet/encoding/mod.rs index 79b608ab63b7e..81d751a3e004c 100644 --- a/crates/polars-parquet/src/parquet/encoding/mod.rs +++ b/crates/polars-parquet/src/parquet/encoding/mod.rs @@ -1,5 +1,3 @@ -use std::convert::TryInto; - pub mod bitpacked; pub mod delta_bitpacked; pub mod delta_byte_array; diff --git a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs index a5a3a7b107358..e685e4147f178 100644 --- a/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs +++ b/crates/polars-parquet/src/parquet/metadata/schema_descriptor.rs @@ -38,7 +38,7 @@ impl SchemaDescriptor { } } - /// The [`ColumnDescriptor`] (leafs) of this schema. + /// The [`ColumnDescriptor`] (leaves) of this schema. /// /// Note that, for nested fields, this may contain more entries than the number of fields /// in the file - e.g. a struct field may have two columns. diff --git a/crates/polars-parquet/src/parquet/mod.rs b/crates/polars-parquet/src/parquet/mod.rs index 3ba6d8b0cadcc..e54600fb4af5f 100644 --- a/crates/polars-parquet/src/parquet/mod.rs +++ b/crates/polars-parquet/src/parquet/mod.rs @@ -18,9 +18,9 @@ pub mod write; use parquet_format_safe as thrift_format; pub use streaming_decompression::{fallible_streaming_iterator, FallibleStreamingIterator}; -const HEADER_SIZE: u64 = PARQUET_MAGIC.len() as u64; -const FOOTER_SIZE: u64 = 8; -const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; +pub const HEADER_SIZE: u64 = PARQUET_MAGIC.len() as u64; +pub const FOOTER_SIZE: u64 = 8; +pub const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; /// The number of bytes read at the end of the parquet file on first read const DEFAULT_FOOTER_READ_SIZE: u64 = 64 * 1024; diff --git a/crates/polars-parquet/src/parquet/parquet_bridge.rs b/crates/polars-parquet/src/parquet/parquet_bridge.rs index eec75e4994caa..e3851d211be8f 100644 --- a/crates/polars-parquet/src/parquet/parquet_bridge.rs +++ b/crates/polars-parquet/src/parquet/parquet_bridge.rs @@ -1,5 +1,4 @@ // Bridges structs from thrift-generated code to rust enums. -use std::convert::TryFrom; #[cfg(feature = "serde_types")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-parquet/src/parquet/read/compression.rs b/crates/polars-parquet/src/parquet/read/compression.rs index d5ea2e8f400e3..3366a5c56c66d 100644 --- a/crates/polars-parquet/src/parquet/read/compression.rs +++ b/crates/polars-parquet/src/parquet/read/compression.rs @@ -1,5 +1,4 @@ use parquet_format_safe::DataPageHeaderV2; -use streaming_decompression; use super::page::PageIterator; use crate::parquet::compression::{self, Compression}; diff --git a/crates/polars-parquet/src/parquet/read/indexes/read.rs b/crates/polars-parquet/src/parquet/read/indexes/read.rs index 379fb41507666..9572ccf177233 100644 --- a/crates/polars-parquet/src/parquet/read/indexes/read.rs +++ b/crates/polars-parquet/src/parquet/read/indexes/read.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::{Cursor, Read, Seek, SeekFrom}; use parquet_format_safe::thrift::protocol::TCompactInputProtocol; diff --git a/crates/polars-parquet/src/parquet/read/metadata.rs b/crates/polars-parquet/src/parquet/read/metadata.rs index a75b939a513c2..10864e194aebb 100644 --- a/crates/polars-parquet/src/parquet/read/metadata.rs +++ b/crates/polars-parquet/src/parquet/read/metadata.rs @@ -1,5 +1,4 @@ use std::cmp::min; -use std::convert::TryInto; use std::io::{Read, Seek, SeekFrom}; use parquet_format_safe::thrift::protocol::TCompactInputProtocol; diff --git a/crates/polars-parquet/src/parquet/read/page/reader.rs b/crates/polars-parquet/src/parquet/read/page/reader.rs index e0078f97c6d4a..0f1c7d0fb0f32 100644 --- a/crates/polars-parquet/src/parquet/read/page/reader.rs +++ b/crates/polars-parquet/src/parquet/read/page/reader.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::Read; use std::sync::Arc; diff --git a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs index 10f3f0614dce6..a27a0b9a57a88 100644 --- a/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs +++ b/crates/polars-parquet/src/parquet/schema/io_message/from_message.rs @@ -619,10 +619,10 @@ impl<'a> Parser<'a> { #[cfg(test)] mod tests { - use types::{IntegerType, PrimitiveLogicalType}; + use types::IntegerType; use super::*; - use crate::parquet::schema::types::{GroupConvertedType, PhysicalType, PrimitiveConvertedType}; + use crate::parquet::schema::types::PhysicalType; #[test] fn test_tokenize_empty_string() { diff --git a/crates/polars-parquet/src/parquet/types.rs b/crates/polars-parquet/src/parquet/types.rs index f2e7b1472eb37..b9d93a91bd261 100644 --- a/crates/polars-parquet/src/parquet/types.rs +++ b/crates/polars-parquet/src/parquet/types.rs @@ -1,5 +1,3 @@ -use std::convert::TryFrom; - use crate::parquet::schema::types::PhysicalType; /// A physical native representation of a Parquet fixed-sized type. diff --git a/crates/polars-parquet/src/parquet/write/page.rs b/crates/polars-parquet/src/parquet/write/page.rs index 1f024b629f078..ad6bc32efc685 100644 --- a/crates/polars-parquet/src/parquet/write/page.rs +++ b/crates/polars-parquet/src/parquet/write/page.rs @@ -1,4 +1,3 @@ -use std::convert::TryInto; use std::io::Write; use std::sync::Arc; diff --git a/crates/polars-parquet/tests/it/main.rs b/crates/polars-parquet/tests/it/main.rs deleted file mode 100644 index e108117e793af..0000000000000 --- a/crates/polars-parquet/tests/it/main.rs +++ /dev/null @@ -1 +0,0 @@ -mod roundtrip; diff --git a/crates/polars-pipe/Cargo.toml b/crates/polars-pipe/Cargo.toml index 8f480871ac5b6..7c9615ba9d094 100644 --- a/crates/polars-pipe/Cargo.toml +++ b/crates/polars-pipe/Cargo.toml @@ -38,7 +38,7 @@ cloud = ["async", "polars-io/cloud", "polars-plan/cloud", "tokio", "futures"] parquet = ["polars-plan/parquet", "polars-io/parquet", "polars-io/async"] ipc = ["polars-plan/ipc", "polars-io/ipc"] json = ["polars-plan/json", "polars-io/json"] -async = ["polars-plan/async", "polars-io/async"] +async = ["polars-plan/async", "polars-io/async", "futures"] nightly = ["polars-core/nightly", "polars-utils/nightly", "hashbrown/nightly"] cross_join = ["polars-ops/cross_join"] dtype-u8 = ["polars-core/dtype-u8"] diff --git a/crates/polars-pipe/src/executors/operators/pass.rs b/crates/polars-pipe/src/executors/operators/pass.rs index c80e0fbe61045..6b2189dbc2b8f 100644 --- a/crates/polars-pipe/src/executors/operators/pass.rs +++ b/crates/polars-pipe/src/executors/operators/pass.rs @@ -27,6 +27,6 @@ impl Operator for Pass { } fn fmt(&self) -> &str { - "pass" + self.name } } diff --git a/crates/polars-pipe/src/executors/operators/placeholder.rs b/crates/polars-pipe/src/executors/operators/placeholder.rs index 13987828cedab..8a12be07d12b6 100644 --- a/crates/polars-pipe/src/executors/operators/placeholder.rs +++ b/crates/polars-pipe/src/executors/operators/placeholder.rs @@ -1,9 +1,75 @@ +use std::sync::{Arc, Mutex}; + use polars_core::error::PolarsResult; use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; -#[derive(Default)] -pub struct PlaceHolder {} +#[derive(Clone)] +struct CallBack { + inner: Arc>>>, +} + +impl CallBack { + fn new() -> Self { + Self { + inner: Default::default(), + } + } + + fn replace(&self, op: Box) { + let mut lock = self.inner.try_lock().expect("no-contention"); + *lock = Some(op); + } +} + +impl Operator for CallBack { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + let mut lock = self.inner.try_lock().expect("no-contention"); + lock.as_mut().unwrap().execute(context, chunk) + } + + fn flush(&mut self) -> PolarsResult { + let mut lock = self.inner.try_lock().expect("no-contention"); + lock.as_mut().unwrap().flush() + } + + fn must_flush(&self) -> bool { + let lock = self.inner.try_lock().expect("no-contention"); + lock.as_ref().unwrap().must_flush() + } + + fn split(&self, _thread_no: usize) -> Box { + panic!("should not be called") + } + + fn fmt(&self) -> &str { + "callback" + } +} + +#[derive(Clone, Default)] +pub struct PlaceHolder { + inner: Arc>>, +} + +impl PlaceHolder { + pub fn new() -> Self { + Self { + inner: Arc::new(Default::default()), + } + } + + pub fn replace(&self, op: Box) { + let inner = self.inner.lock().unwrap(); + for (thread_no, cb) in inner.iter() { + cb.replace(op.split(*thread_no)) + } + } +} impl Operator for PlaceHolder { fn execute( @@ -14,8 +80,11 @@ impl Operator for PlaceHolder { panic!("placeholder should be replaced") } - fn split(&self, _thread_no: usize) -> Box { - Box::new(Self {}) + fn split(&self, thread_no: usize) -> Box { + let cb = CallBack::new(); + let mut inner = self.inner.lock().unwrap(); + inner.push((thread_no, cb.clone())); + Box::new(cb) } fn fmt(&self) -> &str { diff --git a/crates/polars-pipe/src/executors/operators/projection.rs b/crates/polars-pipe/src/executors/operators/projection.rs index 1501da49d5fa0..f1271457417c9 100644 --- a/crates/polars-pipe/src/executors/operators/projection.rs +++ b/crates/polars-pipe/src/executors/operators/projection.rs @@ -98,7 +98,7 @@ impl Operator for ProjectionOperator { } } - let chunk = chunk.with_data(DataFrame::new_no_checks(projected)); + let chunk = chunk.with_data(unsafe { DataFrame::new_no_checks(projected) }); Ok(OperatorResult::Finished(chunk)) } fn split(&self, _thread_no: usize) -> Box { @@ -149,7 +149,8 @@ impl Operator for HstackOperator { .map(|e| e.evaluate(chunk, context.execution_state.as_any())) .collect::>>()?; - let mut df = DataFrame::new_no_checks(chunk.data.get_columns()[..width].to_vec()); + let columns = chunk.data.get_columns()[..width].to_vec(); + let mut df = unsafe { DataFrame::new_no_checks(columns) }; let schema = &*self.input_schema; if self.unchecked { diff --git a/crates/polars-pipe/src/executors/operators/reproject.rs b/crates/polars-pipe/src/executors/operators/reproject.rs index 3383759696361..ca2bd5cb1e78f 100644 --- a/crates/polars-pipe/src/executors/operators/reproject.rs +++ b/crates/polars-pipe/src/executors/operators/reproject.rs @@ -1,26 +1,8 @@ +use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; -use polars_core::prelude::SchemaRef; use polars_core::schema::Schema; -use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext, PolarsResult}; - -/// An operator that will ensure we keep the schema order -pub(crate) struct ReProjectOperator { - schema: SchemaRef, - operator: Box, - // cache the positions - positions: Vec, -} - -impl ReProjectOperator { - pub(crate) fn new(schema: SchemaRef, operator: Box) -> Self { - ReProjectOperator { - schema, - operator, - positions: vec![], - } - } -} +use crate::operators::DataChunk; pub(crate) fn reproject_chunk( chunk: &mut DataChunk, @@ -45,76 +27,8 @@ pub(crate) fn reproject_chunk( } else { let columns = chunk.data.get_columns(); let cols = positions.iter().map(|i| columns[*i].clone()).collect(); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } }; *chunk = chunk.with_data(out); Ok(()) } - -impl Operator for ReProjectOperator { - fn execute( - &mut self, - context: &PExecutionContext, - chunk: &DataChunk, - ) -> PolarsResult { - let (mut chunk, finished) = match self.operator.execute(context, chunk)? { - OperatorResult::Finished(chunk) => (chunk, true), - OperatorResult::HaveMoreOutPut(chunk) => (chunk, false), - OperatorResult::NeedsNewData => return Ok(OperatorResult::NeedsNewData), - }; - reproject_chunk(&mut chunk, &mut self.positions, self.schema.as_ref())?; - Ok(if finished { - OperatorResult::Finished(chunk) - } else { - OperatorResult::HaveMoreOutPut(chunk) - }) - } - - fn split(&self, thread_no: usize) -> Box { - let operator = self.operator.split(thread_no); - Box::new(Self { - schema: self.schema.clone(), - positions: self.positions.clone(), - operator, - }) - } - - fn fmt(&self) -> &str { - "re-project-operator" - } -} - -#[cfg(test)] -mod test { - use polars_core::prelude::*; - - use super::*; - - #[test] - fn test_reproject_chunk() { - let df = df![ - "a" => [1, 2], - "b" => [1, 2], - "c" => [1, 2], - "d" => [1, 2], - ] - .unwrap(); - - let mut chunk1 = DataChunk::new(0, df.clone()); - let mut chunk2 = DataChunk::new(1, df); - - let mut positions = vec![]; - - let mut out_schema = Schema::new(); - out_schema.with_column("c".into(), DataType::Int32); - out_schema.with_column("b".into(), DataType::Int32); - out_schema.with_column("d".into(), DataType::Int32); - out_schema.with_column("a".into(), DataType::Int32); - - reproject_chunk(&mut chunk1, &mut positions, &out_schema).unwrap(); - // second call cached the positions - reproject_chunk(&mut chunk2, &mut positions, &out_schema).unwrap(); - assert_eq!(&chunk1.data.schema(), &out_schema); - assert_eq!(&chunk2.data.schema(), &out_schema); - } -} diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs index 21a53de6db747..757963b91a68c 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/convert.rs @@ -131,7 +131,7 @@ where AExpr::Len => ( IDX_DTYPE, Arc::new(Len {}), - AggregateFunction::Count(CountAgg::new()), + AggregateFunction::Len(CountAgg::new()), ), AExpr::Agg(agg) => match agg { AAggExpr::Min { input, .. } => { diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs index 8180cc1159298..0e7581cfd5e9d 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/count.rs @@ -7,26 +7,28 @@ use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; use crate::operators::IdxSize; -pub(crate) struct CountAgg { +pub(crate) struct CountAgg { count: IdxSize, } -impl CountAgg { +impl CountAgg { pub(crate) fn new() -> Self { CountAgg { count: 0 } } - fn incr(&mut self) { - self.count += 1; - } } -impl AggregateFn for CountAgg { +impl AggregateFn for CountAgg { fn has_physical_agg(&self) -> bool { false } - fn pre_agg(&mut self, _chunk_idx: IdxSize, _item: &mut dyn ExactSizeIterator) { - self.incr(); + fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator) { + let item = unsafe { item.next().unwrap_unchecked_release() }; + if INCLUDE_NULL { + self.count += 1; + } else { + self.count += !matches!(item, AnyValue::Null) as IdxSize; + } } fn pre_agg_ordered( &mut self, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs index 604502902d53b..60b50b144aa04 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/first.rs @@ -55,7 +55,7 @@ impl AggregateFn for FirstAgg { fn combine(&mut self, other: &dyn Any) { let other = unsafe { other.downcast_ref::().unwrap_unchecked_release() }; if other.first.is_some() && other.chunk_idx < self.chunk_idx { - self.first = other.first.clone(); + self.first.clone_from(&other.first); self.chunk_idx = other.chunk_idx; }; } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs index 8ea888858fa18..2496ac21add31 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/interface.rs @@ -46,7 +46,8 @@ pub(crate) trait AggregateFn: Send + Sync { pub(crate) enum AggregateFunction { First(FirstAgg), Last(LastAgg), - Count(CountAgg), + Count(CountAgg), + Len(CountAgg), SumF32(SumAgg), SumF64(SumAgg), SumU32(SumAgg), @@ -83,6 +84,7 @@ impl AggregateFunction { MeanF32(_) => MeanF32(MeanAgg::new()), MeanF64(_) => MeanF64(MeanAgg::new()), Count(_) => Count(CountAgg::new()), + Len(_) => Len(CountAgg::new()), Null(a) => Null(a.clone()), MinMaxF32(inner) => MinMaxF32(inner.split()), MinMaxF64(inner) => MinMaxF64(inner.split()), diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs index 08f211359064f..2a659d1aea01f 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/last.rs @@ -52,7 +52,7 @@ impl AggregateFn for LastAgg { fn combine(&mut self, other: &dyn Any) { let other = unsafe { other.downcast_ref::().unwrap_unchecked_release() }; if other.last.is_some() && other.chunk_idx >= self.chunk_idx { - self.last = other.last.clone(); + self.last.clone_from(&other.last); self.chunk_idx = other.chunk_idx; }; } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs index 8e07ca2aa4d29..82afe6a0b40c5 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/mean.rs @@ -11,7 +11,6 @@ use polars_core::utils::arrow::compute::aggregate::sum_primitive; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; -use crate::operators::{ArrowDataType, IdxSize}; pub struct MeanAgg { sum: Option, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs index 8466031b6114b..341bb067635b7 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/min_max.rs @@ -2,14 +2,12 @@ use std::any::Any; use arrow::array::PrimitiveArray; use polars_compute::min_max::MinMaxKernel; -use polars_core::datatypes::{AnyValue, DataType}; use polars_core::export::num::NumCast; use polars_core::prelude::*; use polars_utils::min_max::MinMax; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; -use crate::operators::{ArrowDataType, IdxSize}; pub(super) fn new_min() -> MinMaxAgg K> { MinMaxAgg::new(MinMax::min_ignore_nan, true) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs index 8f2fdf9636380..b256ca41720f3 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/aggregates/sum.rs @@ -10,7 +10,6 @@ use polars_core::utils::arrow::compute::aggregate::sum_primitive; use polars_utils::unwrap::UnwrapUncheckedRelease; use super::*; -use crate::operators::{ArrowDataType, IdxSize}; pub struct SumAgg { sum: Option, diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs index cf00212422ccd..3fa0b384dd0c3 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/eval.rs @@ -1,6 +1,5 @@ use std::cell::UnsafeCell; -use arrow::array::{ArrayRef, BinaryArray}; use polars_core::export::ahash::RandomState; use polars_row::{RowsEncoded, SortField}; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs index 5663775056ddc..4488a6faad820 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/global.rs @@ -1,6 +1,5 @@ use std::collections::LinkedList; use std::sync::atomic::{AtomicU16, Ordering}; -use std::sync::Mutex; use polars_core::utils::accumulate_dataframes_vertical_unchecked; use polars_core::POOL; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs index 703403cf5e8b7..43023493c7697 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs @@ -131,11 +131,11 @@ impl AggHashTable { pub(super) unsafe fn insert( &mut self, hash: u64, - row: &[u8], + key: &[u8], agg_iters: &mut [SeriesPhysIter], chunk_index: IdxSize, ) -> bool { - let agg_idx = match self.insert_key(hash, row) { + let agg_idx = match self.insert_key(hash, key) { // overflow None => return true, Some(agg_idx) => agg_idx, @@ -275,7 +275,7 @@ impl AggHashTable { ); cols.extend(agg_builders.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs index 16457b63bc0d5..41967ee854261 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/mod.rs @@ -83,7 +83,7 @@ impl SpillPayload { cols.push(chunk_idx); cols.push(keys); cols.extend(self.aggs); - DataFrame::new_no_checks(cols) + unsafe { DataFrame::new_no_checks(cols) } } fn spilled_to_columns( diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs index 1b8610c54251d..77a939c642906 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/ooc_state.rs @@ -1,7 +1,6 @@ use polars_core::config::verbose; use super::*; -use crate::executors::sinks::io::IOThread; use crate::executors::sinks::memory::MemTracker; use crate::pipeline::{morsels_per_sink, FORCE_OOC}; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs index bdb52235b3b79..4ebaf073525bd 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/source.rs @@ -4,7 +4,7 @@ use polars_io::SerReader; use super::*; use crate::executors::sinks::group_by::generic::global::GlobalTable; -use crate::executors::sinks::io::{block_thread_until_io_thread_done, IOThread}; +use crate::executors::sinks::io::block_thread_until_io_thread_done; use crate::operators::{Source, SourceResult}; use crate::pipeline::PARTITION_SIZE; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs b/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs index 1c57aad65cbd5..4d808b7265334 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/ooc.rs @@ -8,8 +8,8 @@ use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, Source use crate::pipeline::{morsels_per_sink, PipeLine}; pub(super) struct GroupBySource { - // holding this keeps the lockfile in place - _io_thread: IOThread, + // Holding this keeps the lockfile in place + io_thread: IOThread, already_finished: Option, partitions: std::fs::ReadDir, group_by_sink: Box, @@ -34,7 +34,7 @@ impl GroupBySource { } Ok(Self { - _io_thread: io_thread, + io_thread, already_finished: Some(already_finished), partitions, group_by_sink, @@ -85,7 +85,7 @@ impl Source for GroupBySource { let mut pipe = PipeLine::new_simple(sources, vec![], self.group_by_sink.split(0), verbose()); - match pipe.run_pipeline(context, Default::default())?.unwrap() { + let out = match pipe.run_pipeline(context, &mut vec![])?.unwrap() { FinalizedSink::Finished(mut df) => { if let Some(slice) = &mut self.slice { let height = df.height(); @@ -118,7 +118,12 @@ impl Source for GroupBySource { // recursively out of core path FinalizedSink::Source(mut src) => src.get_batches(context), _ => unreachable!(), + }; + for path in files { + self.io_thread.clean(path) } + + out }, } } diff --git a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs index 584bcb64f7af2..30fb437bd6bd8 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/primitive/mod.rs @@ -209,7 +209,7 @@ where cols.push(key_builder.finish().into_series()); cols.extend(buffers.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); - Some(DataFrame::new_no_checks(cols)) + Some(unsafe { DataFrame::new_no_checks(cols) }) }) .collect::>(); Ok(dfs) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/string.rs b/crates/polars-pipe/src/executors/sinks/group_by/string.rs index 8d164928105aa..9d8bbf6e55470 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/string.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/string.rs @@ -213,7 +213,7 @@ impl StringGroupbySink { cols.push(key_builder.finish().into_series()); cols.extend(buffers.into_iter().map(|buf| buf.into_series())); physical_agg_to_logical(&mut cols, &self.output_schema); - Some(DataFrame::new_no_checks(cols)) + Some(unsafe { DataFrame::new_no_checks(cols) }) }) .collect::>(); diff --git a/crates/polars-pipe/src/executors/sinks/io.rs b/crates/polars-pipe/src/executors/sinks/io.rs index 18e9206856656..ac2b27717a75e 100644 --- a/crates/polars-pipe/src/executors/sinks/io.rs +++ b/crates/polars-pipe/src/executors/sinks/io.rs @@ -2,10 +2,9 @@ use std::fs; use std::fs::File; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; use std::time::{Duration, SystemTime}; -use crossbeam_channel::{bounded, Sender}; +use crossbeam_channel::{bounded, unbounded, Receiver, Sender}; use polars_core::error::ErrString; use polars_core::prelude::*; use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY; @@ -21,7 +20,8 @@ type Payload = (Option, DfIter); /// A helper that can be used to spill to disk pub(crate) struct IOThread { - sender: Sender, + payload_tx: Sender, + cleanup_tx: Sender, _lockfile: Arc, pub(in crate::executors::sinks) dir: PathBuf, pub(in crate::executors::sinks) sent: Arc, @@ -58,10 +58,24 @@ fn get_spill_dir(operation_name: &'static str) -> PolarsResult { Ok(dir) } +fn clean_after_delay(time: Option, secs: u64, path: &Path) { + if let Some(time) = time { + let modified_since = SystemTime::now().duration_since(time).unwrap().as_secs(); + if modified_since > secs { + // This can be fallible if another thread removes this. + // That is fine. + let _ = std::fs::remove_dir_all(path); + } + } else { + polars_warn!("could not modified time on this platform") + } +} + /// Starts a new thread that will clean up operations of directories that don't /// have a lockfile (opened with 'w' permissions). -fn gc_thread(operation_name: &'static str) { +fn gc_thread(operation_name: &'static str, rx: Receiver) { let _ = std::thread::spawn(move || { + // First clean all existing let mut dir = std::path::PathBuf::from(get_base_temp_dir()); dir.push(&format!("polars/{operation_name}")); @@ -78,25 +92,35 @@ fn gc_thread(operation_name: &'static str) { if let Ok(lockfile) = File::open(lockfile_path) { // lockfile can be read - if let Ok(time) = lockfile.metadata().unwrap().modified() { - let modified_since = - SystemTime::now().duration_since(time).unwrap().as_secs(); - // the lockfile can still exist if a process was canceled + if let Ok(md) = lockfile.metadata() { + let time = md.modified().ok(); + // The lockfile can still exist if a process was canceled // so we also check the modified date - // we don't expect queries that run a month - if modified_since > (SECONDS_IN_DAY as u64 * 30) { - std::fs::remove_dir_all(path).unwrap() - } - } else { - eprintln!("could not modified time on this platform") + // we don't expect queries that run a month. + clean_after_delay(time, SECONDS_IN_DAY as u64 * 30, &path); } } else { - // This can be fallible as another Polars query could already have removed this. - // So we ignore the result. - let _ = std::fs::remove_dir_all(path); + // If path already removed, we simply continue. + if let Ok(md) = path.metadata() { + let time = md.modified().ok(); + // Wait 15 seconds to ensure we don't remove before lockfile is created + // in a `collect_all` contention case + clean_after_delay(time, 15, &path); + } } } } + + // Clean on receive + while let Ok(path) = rx.recv() { + if path.is_file() { + let res = std::fs::remove_file(path); + debug_assert!(res.is_ok()); + } else { + let res = std::fs::remove_dir_all(path); + debug_assert!(res.is_ok()); + } + } }); } @@ -113,12 +137,13 @@ impl IOThread { let lockfile_path = get_lockfile_path(&dir); let lockfile = Arc::new(LockFile::new(lockfile_path)?); + let (cleanup_tx, rx) = unbounded::(); // start a thread that will clean up old dumps. // TODO: if we will have more ooc in the future we will have a dedicated GC thread - gc_thread(operation_name); + gc_thread(operation_name, rx); // we need some pushback otherwise we still could go OOM. - let (sender, receiver) = bounded::(morsels_per_sink() * 2); + let (tx, rx) = bounded::(morsels_per_sink() * 2); let sent: Arc = Default::default(); let total: Arc = Default::default(); @@ -141,9 +166,10 @@ impl IOThread { // This will dump to `dir/count.ipc` // 2. (Some(partitions), DfIter) // This will dump to `dir/partition/count.ipc` - while let Ok((partitions, iter)) = receiver.recv() { + while let Ok((partitions, iter)) = rx.recv() { if let Some(partitions) = partitions { - for (part, df) in partitions.into_no_null_iter().zip(iter) { + for (part, mut df) in partitions.into_no_null_iter().zip(iter) { + df.shrink_to_fit(); let mut path = dir2.clone(); path.push(format!("{part}")); @@ -159,13 +185,14 @@ impl IOThread { } } else { let mut path = dir2.clone(); - path.push(format!("{count}.ipc")); + path.push(format!("{count}_0_pass.ipc")); let file = File::create(path).unwrap(); let writer = IpcWriter::new(file).with_pl_flavor(true); let mut writer = writer.batched(&schema).unwrap(); - for df in iter { + for mut df in iter { + df.shrink_to_fit(); writer.write_batch(&df).unwrap(); } writer.finish().unwrap(); @@ -177,7 +204,8 @@ impl IOThread { }); Ok(Self { - sender, + payload_tx: tx, + cleanup_tx, dir, sent, total, @@ -190,12 +218,13 @@ impl IOThread { pub(in crate::executors::sinks) fn dump_chunk(&self, mut df: DataFrame) { // if IO thread is blocked // we write locally on this thread - if self.sender.is_full() { + if self.payload_tx.is_full() { + df.shrink_to_fit(); let mut path = self.dir.clone(); let count = self.thread_local_count.fetch_add(1, Ordering::Relaxed); // thread local name we start with an underscore to ensure we don't get // duplicates - path.push(format!("_{count}.ipc")); + path.push(format!("_{count}_full.ipc")); let file = File::create(path).unwrap(); let mut writer = IpcWriter::new(file).with_pl_flavor(true); @@ -206,6 +235,10 @@ impl IOThread { } } + pub(in crate::executors::sinks) fn clean(&self, path: PathBuf) { + self.cleanup_tx.send(path).unwrap() + } + pub(in crate::executors::sinks) fn dump_partition(&self, partition_no: IdxSize, df: DataFrame) { let partition = Some(IdxCa::from_vec("", vec![partition_no])); let iter = Box::new(std::iter::once(df)); @@ -215,8 +248,9 @@ impl IOThread { pub(in crate::executors::sinks) fn dump_partition_local( &self, partition_no: IdxSize, - df: DataFrame, + mut df: DataFrame, ) { + df.shrink_to_fit(); let count = self.thread_local_count.fetch_add(1, Ordering::Relaxed); let mut path = self.dir.clone(); path.push(format!("{partition_no}")); @@ -234,7 +268,7 @@ impl IOThread { pub(in crate::executors::sinks) fn dump_iter(&self, partition: Option, iter: DfIter) { let add = iter.size_hint().1.unwrap(); - self.sender.send((partition, iter)).unwrap(); + self.payload_tx.send((partition, iter)).unwrap(); self.sent.fetch_add(add, Ordering::Relaxed); } } @@ -261,10 +295,11 @@ struct LockFile { impl LockFile { fn new(path: PathBuf) -> PolarsResult { - if File::create(&path).is_ok() { - Ok(Self { path }) - } else { - polars_bail!(ComputeError: "could not create lockfile") + match File::create(&path) { + Ok(_) => Ok(Self { path }), + Err(e) => { + polars_bail!(ComputeError: "could not create lockfile: {e}") + }, } } } diff --git a/crates/polars-pipe/src/executors/sinks/joins/cross.rs b/crates/polars-pipe/src/executors/sinks/joins/cross.rs index 08a29ebbec6c4..491cd5c36b500 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/cross.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/cross.rs @@ -7,8 +7,10 @@ use std::vec; use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_ops::prelude::CrossJoin as CrossJoinTrait; +use polars_utils::arena::Node; use smartstring::alias::String as SmartString; +use crate::executors::operators::PlaceHolder; use crate::operators::{ chunks_to_df_unchecked, DataChunk, FinalizedSink, Operator, OperatorResult, PExecutionContext, Sink, SinkResult, @@ -19,19 +21,35 @@ pub struct CrossJoin { chunks: Vec, suffix: SmartString, swapped: bool, + node: Node, + placeholder: PlaceHolder, } impl CrossJoin { - pub(crate) fn new(suffix: SmartString, swapped: bool) -> Self { + pub(crate) fn new( + suffix: SmartString, + swapped: bool, + node: Node, + placeholder: PlaceHolder, + ) -> Self { CrossJoin { chunks: vec![], suffix, swapped, + node, + placeholder, } } } impl Sink for CrossJoin { + fn node(&self) -> Node { + self.node + } + fn is_join_build(&self) -> bool { + true + } + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { self.chunks.push(chunk); Ok(SinkResult::CanHaveMoreInput) @@ -47,13 +65,13 @@ impl Sink for CrossJoin { Box::new(Self { suffix: self.suffix.clone(), swapped: self.swapped, + placeholder: self.placeholder.clone(), ..Default::default() }) } fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { - // todo! share sink - Ok(FinalizedSink::Operator(Box::new(CrossJoinProbe { + let op = Box::new(CrossJoinProbe { df: Arc::new(chunks_to_df_unchecked(std::mem::take(&mut self.chunks))), suffix: Arc::from(self.suffix.as_ref()), in_process_left: None, @@ -61,7 +79,10 @@ impl Sink for CrossJoin { in_process_left_df: Default::default(), output_names: None, swapped: self.swapped, - }))) + }); + self.placeholder.replace(op); + + Ok(FinalizedSink::Operator) } fn as_any(&mut self) -> &mut dyn Any { diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index 34013f9624cac..864020d1a8a19 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -1,21 +1,19 @@ use std::any::Any; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; -use arrow::array::{ArrayRef, BinaryArray}; +use arrow::array::BinaryArray; use hashbrown::hash_map::RawEntryMut; -use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; -use polars_utils::hashing::hash_to_partition; -use polars_utils::idx_vec::UnitVec; -use polars_utils::index::ChunkId; +use polars_utils::arena::Node; use polars_utils::slice::GetSaferUnchecked; use polars_utils::unitvec; +use smartstring::alias::String as SmartString; use super::*; +use crate::executors::operators::PlaceHolder; use crate::executors::sinks::joins::generic_probe_inner_left::GenericJoinProbe; +use crate::executors::sinks::joins::generic_probe_outer::GenericOuterJoinProbe; use crate::executors::sinks::utils::{hash_rows, load_vec}; use crate::executors::sinks::HASHMAP_INIT_SIZE; use crate::expressions::PhysicalPipedExpr; @@ -24,33 +22,7 @@ use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkRe pub(super) type ChunkIdx = IdxSize; pub(super) type DfIdx = IdxSize; -// This is the hash and the Index offset in the chunks and the index offset in the dataframe -#[derive(Copy, Clone, Debug)] -pub(super) struct Key { - pub(super) hash: u64, - chunk_idx: IdxSize, - df_idx: IdxSize, -} - -impl Key { - #[inline] - fn new(hash: u64, chunk_idx: IdxSize, df_idx: IdxSize) -> Self { - Key { - hash, - chunk_idx, - df_idx, - } - } -} - -impl Hash for Key { - #[inline] - fn hash(&self, state: &mut H) { - state.write_u64(self.hash) - } -} - -pub struct GenericBuild { +pub struct GenericBuild { chunks: Vec, // the join columns are all tightly packed // the values of a join column(s) can be found @@ -64,7 +36,7 @@ pub struct GenericBuild { hb: RandomState, // partitioned tables that will be used for probing // stores the key and the chunk_idx, df_idx of the left table - hash_tables: Vec>>, + hash_tables: PartitionedMap, // the columns that will be joined on join_columns_left: Arc>>, @@ -77,9 +49,14 @@ pub struct GenericBuild { // the join order is swapped to ensure we hash the smaller table swapped: bool, join_nulls: bool, + node: Node, + key_names_left: Arc<[SmartString]>, + key_names_right: Arc<[SmartString]>, + placeholder: PlaceHolder, } -impl GenericBuild { +impl GenericBuild { + #[allow(clippy::too_many_arguments)] pub(crate) fn new( suffix: Arc, join_type: JoinType, @@ -87,10 +64,16 @@ impl GenericBuild { join_columns_left: Arc>>, join_columns_right: Arc>>, join_nulls: bool, + node: Node, + key_names_left: Arc<[SmartString]>, + key_names_right: Arc<[SmartString]>, + placeholder: PlaceHolder, ) -> Self { let hb: RandomState = Default::default(); let partitions = _set_partition_size(); - let hash_tables = load_vec(partitions, || PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE)); + let hash_tables = PartitionedHashMap::new(load_vec(partitions, || { + PlIdHashMap::with_capacity(HASHMAP_INIT_SIZE) + })); GenericBuild { chunks: vec![], join_type, @@ -104,6 +87,10 @@ impl GenericBuild { hash_tables, hashes: vec![], join_nulls, + node, + key_names_left, + key_names_right, + placeholder, } } } @@ -121,8 +108,9 @@ pub(super) fn compare_fn( // as that has no indirection key_hash == h && { // we get the appropriate values from the join columns and compare it with the current row - let chunk_idx = key.chunk_idx as usize; - let df_idx = key.df_idx as usize; + let (chunk_idx, df_idx) = key.idx.extract(); + let chunk_idx = chunk_idx as usize; + let df_idx = df_idx as usize; // get the right columns from the linearly packed buffer let other_row = unsafe { @@ -134,7 +122,7 @@ pub(super) fn compare_fn( } } -impl GenericBuild { +impl GenericBuild { fn is_empty(&self) -> bool { match self.chunks.len() { 0 => true, @@ -165,7 +153,14 @@ impl GenericBuild { } } -impl Sink for GenericBuild { +impl Sink for GenericBuild { + fn node(&self) -> Node { + self.node + } + fn is_join_build(&self) -> bool { + true + } + fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { // we do some juggling here so that we don't // end up with empty chunks @@ -190,11 +185,7 @@ impl Sink for GenericBuild { // row offset in the chunk belonging to the hash let mut current_df_idx = 0 as IdxSize; for (row, h) in rows.values_iter().zip(&self.hashes) { - // get the hashtable belonging to this hash partition - let partition = hash_to_partition(*h, self.hash_tables.len()); - let current_table = unsafe { self.hash_tables.get_unchecked_release_mut(partition) }; - - let entry = current_table.raw_entry_mut().from_hash(*h, |key| { + let entry = self.hash_tables.raw_entry_mut(*h).from_hash(*h, |key| { compare_fn(key, *h, &self.materialized_join_cols, row) }); @@ -202,10 +193,10 @@ impl Sink for GenericBuild { match entry { RawEntryMut::Vacant(entry) => { let key = Key::new(*h, current_chunk_offset, current_df_idx); - entry.insert(key, unitvec![payload]); + entry.insert(key, (unitvec![payload], Default::default())); }, RawEntryMut::Occupied(mut entry) => { - entry.get_mut().push(payload); + entry.get_mut().0.push(payload); }, }; @@ -241,12 +232,15 @@ impl Sink for GenericBuild { // we combine the other hashtable with ours, but we must offset the chunk_idx // values by the number of chunks we already got. self.hash_tables + .inner_mut() .iter_mut() - .zip(&other.hash_tables) + .zip(other.hash_tables.inner()) .for_each(|(ht, other_ht)| { for (k, val) in other_ht.iter() { + let val = &val.0; + let (chunk_idx, df_idx) = k.idx.extract(); // use the indexes to materialize the row - let other_row = unsafe { other.get_row(k.chunk_idx, k.df_idx) }; + let other_row = unsafe { other.get_row(chunk_idx, df_idx) }; let h = k.hash; let entry = ht.raw_entry_mut().from_hash(h, |key| { @@ -267,14 +261,14 @@ impl Sink for GenericBuild { }); payload.extend(iter); } - entry.insert(key, payload); + entry.insert(key, (payload, Default::default())); }, RawEntryMut::Occupied(mut entry) => { let iter = val.iter().map(|chunk_id| { let (chunk_idx, val_idx) = chunk_id.extract(); ChunkId::store(chunk_idx + chunks_offset, val_idx) }); - entry.get_mut().extend(iter); + entry.get_mut().0.extend(iter); }, } } @@ -289,37 +283,42 @@ impl Sink for GenericBuild { self.join_columns_left.clone(), self.join_columns_right.clone(), self.join_nulls, + self.node, + self.key_names_left.clone(), + self.key_names_right.clone(), + self.placeholder.clone(), ); new.hb = self.hb.clone(); Box::new(new) } fn finalize(&mut self, context: &PExecutionContext) -> PolarsResult { + let chunks_len = self.chunks.len(); + let left_df = accumulate_dataframes_vertical_unchecked( + std::mem::take(&mut self.chunks) + .into_iter() + .map(|chunk| chunk.data), + ); + if left_df.height() > 0 { + assert_eq!(left_df.n_chunks(), chunks_len); + } + // Reallocate to Arc<[]> to get rid of double indirection as this is accessed on every + // hashtable cmp. + let materialized_join_cols = Arc::from(std::mem::take(&mut self.materialized_join_cols)); + let suffix = self.suffix.clone(); + let hb = self.hb.clone(); + let hash_tables = Arc::new(PartitionedHashMap::new(std::mem::take( + self.hash_tables.inner_mut(), + ))); + let join_columns_left = self.join_columns_left.clone(); + let join_columns_right = self.join_columns_right.clone(); + + // take the buffers, this saves one allocation + let mut hashes = std::mem::take(&mut self.hashes); + hashes.clear(); + match self.join_type { JoinType::Inner | JoinType::Left => { - let chunks_len = self.chunks.len(); - let left_df = accumulate_dataframes_vertical_unchecked( - std::mem::take(&mut self.chunks) - .into_iter() - .map(|chunk| chunk.data), - ); - if left_df.height() > 0 { - assert_eq!(left_df.n_chunks(), chunks_len); - } - let materialized_join_cols = - Arc::new(std::mem::take(&mut self.materialized_join_cols)); - let suffix = self.suffix.clone(); - let hb = self.hb.clone(); - let hash_tables = Arc::new(std::mem::take(&mut self.hash_tables)); - let join_columns_left = self.join_columns_left.clone(); - let join_columns_right = self.join_columns_right.clone(); - - // take the buffers, this saves one allocation - let mut join_series = std::mem::take(&mut self.join_columns); - join_series.clear(); - let mut hashes = std::mem::take(&mut self.hashes); - hashes.clear(); - let probe_operator = GenericJoinProbe::new( left_df, materialized_join_cols, @@ -329,14 +328,33 @@ impl Sink for GenericBuild { join_columns_left, join_columns_right, self.swapped, - join_series, hashes, context, self.join_type.clone(), self.join_nulls, ); - Ok(FinalizedSink::Operator(Box::new(probe_operator))) + self.placeholder.replace(Box::new(probe_operator)); + Ok(FinalizedSink::Operator) + }, + JoinType::Outer { coalesce } => { + let probe_operator = GenericOuterJoinProbe::new( + left_df, + materialized_join_cols, + suffix, + hb, + hash_tables, + join_columns_left, + self.swapped, + hashes, + self.join_nulls, + coalesce, + self.key_names_left.clone(), + self.key_names_right.clone(), + ); + self.placeholder.replace(Box::new(probe_operator)); + Ok(FinalizedSink::Operator) }, + _ => unimplemented!(), } } diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs index 6d1882d037706..19f63302dfc6e 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_inner_left.rs @@ -1,101 +1,71 @@ use std::borrow::Cow; -use std::sync::Arc; -use arrow::array::{Array, ArrayRef, BinaryArray}; -use arrow::compute::utils::combine_validities_and; -use polars_core::error::PolarsResult; +use arrow::array::{Array, BinaryArray}; use polars_core::export::ahash::RandomState; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_ops::chunked_array::DfTake; use polars_ops::frame::join::_finish_join; use polars_ops::prelude::JoinType; -use polars_row::RowsEncoded; -use polars_utils::hashing::hash_to_partition; -use polars_utils::idx_vec::UnitVec; -use polars_utils::index::ChunkId; use polars_utils::nulls::IsNull; -use polars_utils::slice::GetSaferUnchecked; use smartstring::alias::String as SmartString; use crate::executors::sinks::joins::generic_build::*; +use crate::executors::sinks::joins::row_values::RowValues; +use crate::executors::sinks::joins::{ExtraPayload, PartitionedMap, ToRow}; use crate::executors::sinks::utils::hash_rows; use crate::expressions::PhysicalPipedExpr; use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; #[derive(Clone)] -pub struct GenericJoinProbe { - // all chunks are stacked into a single dataframe - // the dataframe is not rechunked. +pub struct GenericJoinProbe { + /// All chunks are stacked into a single dataframe + /// the dataframe is not rechunked. df_a: Arc, - // the join columns are all tightly packed - // the values of a join column(s) can be found - // by: - // first get the offset of the chunks and multiply that with the number of join - // columns - // * chunk_offset = (idx * n_join_keys) - // * end = (offset + n_join_keys) - materialized_join_cols: Arc>>, + /// The join columns are all tightly packed + /// the values of a join column(s) can be found + /// by: + /// first get the offset of the chunks and multiply that with the number of join + /// columns + /// * chunk_offset = (idx * n_join_keys) + /// * end = (offset + n_join_keys) + materialized_join_cols: Arc<[BinaryArray]>, suffix: Arc, hb: RandomState, - // partitioned tables that will be used for probing - // stores the key and the chunk_idx, df_idx of the left table - hash_tables: Arc>>>, + /// partitioned tables that will be used for probing + /// stores the key and the chunk_idx, df_idx of the left table + hash_tables: Arc>, - // the columns that will be joined on - join_columns_right: Arc>>, - - // amortize allocations - current_rows: RowsEncoded, - join_columns: Vec, - // in inner join these are the left table - // in left join there are the right table + /// Amortize allocations + /// In inner join these are the left table. + /// In left join there are the right table. join_tuples_a: Vec, - join_tuples_a_left_join: Vec>, - // in inner join these are the right table - // in left join there are the left table + /// in inner join these are the right table + /// in left join there are the left table join_tuples_b: Vec, hashes: Vec, - // the join order is swapped to ensure we hash the smaller table + /// the join order is swapped to ensure we hash the smaller table swapped_or_left: bool, - // location of join columns. - // these column locations need to be dropped from the rhs - join_column_idx: Option>, - // cached output names + /// cached output names output_names: Option>, how: JoinType, join_nulls: bool, + row_values: RowValues, } -trait ToRow { - fn get_row(&self) -> &[u8]; -} - -impl ToRow for &[u8] { - fn get_row(&self) -> &[u8] { - self - } -} - -impl ToRow for Option<&[u8]> { - fn get_row(&self) -> &[u8] { - self.unwrap() - } -} - -impl GenericJoinProbe { +impl GenericJoinProbe { #[allow(clippy::too_many_arguments)] pub(super) fn new( mut df_a: DataFrame, - materialized_join_cols: Arc>>, + materialized_join_cols: Arc<[BinaryArray]>, suffix: Arc, hb: RandomState, - hash_tables: Arc>>>, + hash_tables: Arc>, join_columns_left: Arc>>, join_columns_right: Arc>>, swapped_or_left: bool, - join_columns: Vec, - hashes: Vec, + // Re-use the hashes allocation of the build side. + amortized_hashes: Vec, context: &PExecutionContext, how: JoinType, join_nulls: bool, @@ -125,67 +95,16 @@ impl GenericJoinProbe { suffix, hb, hash_tables, - join_columns_right, - join_columns, join_tuples_a: vec![], - join_tuples_a_left_join: vec![], join_tuples_b: vec![], - hashes, + hashes: amortized_hashes, swapped_or_left, - current_rows: Default::default(), - join_column_idx: None, output_names: None, how, join_nulls, + row_values: RowValues::new(join_columns_right, !swapped_or_left), } } - fn set_join_series( - &mut self, - context: &PExecutionContext, - chunk: &DataChunk, - ) -> PolarsResult> { - debug_assert!(self.join_columns.is_empty()); - - let determine_idx = !self.swapped_or_left && self.join_column_idx.is_none(); - let mut names = vec![]; - - for phys_e in self.join_columns_right.iter() { - let s = phys_e.evaluate(chunk, context.execution_state.as_any())?; - let s = s.to_physical_repr().rechunk(); - if determine_idx { - names.push(s.name().to_string()); - } - self.join_columns.push(s.array_ref(0).clone()); - } - - // we determine the indices of the columns that have to be removed - // if swapped the join column is already removed from the `build_df` as that will - // be the rhs one. - if !self.swapped_or_left && self.join_column_idx.is_none() { - let mut idx = names - .iter() - .filter_map(|name| chunk.data.get_column_index(name)) - .collect::>(); - // ensure that it is sorted so that we can later remove columns in - // a predictable order - idx.sort_unstable(); - self.join_column_idx = Some(idx); - } - polars_row::convert_columns_amortized_no_order(&self.join_columns, &mut self.current_rows); - - // SAFETY: we keep rows-encode alive - let array = unsafe { self.current_rows.borrow_array() }; - Ok(if self.join_nulls { - array - } else { - let validity = self - .join_columns - .iter() - .map(|arr| arr.validity().cloned()) - .fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref())); - array.with_validity_typed(validity) - }) - } fn finish_join( &mut self, @@ -228,16 +147,12 @@ impl GenericJoinProbe { for (i, (h, row)) in iter { let df_idx_left = i as IdxSize; - // get the hashtable belonging by this hash partition - let partition = hash_to_partition(*h, self.hash_tables.len()); - let current_table = unsafe { self.hash_tables.get_unchecked_release(partition) }; - let entry = if row.is_null() { None } else { let row = row.get_row(); - current_table - .raw_entry() + self.hash_tables + .raw_entry(*h) .from_hash(*h, |key| { compare_fn(key, *h, &self.materialized_join_cols, row) }) @@ -246,14 +161,14 @@ impl GenericJoinProbe { match entry { Some(indexes_right) => { - self.join_tuples_a_left_join - .extend(indexes_right.iter().copied().map(Some)); + let indexes_right = &indexes_right.0; + self.join_tuples_a.extend_from_slice(indexes_right); self.join_tuples_b .extend(std::iter::repeat(df_idx_left).take(indexes_right.len())); }, None => { self.join_tuples_b.push(df_idx_left); - self.join_tuples_a_left_join.push(None); + self.join_tuples_a.push(ChunkId::null()); }, } } @@ -268,10 +183,12 @@ impl GenericJoinProbe { // and streams the left table through. This allows us to maintain // the left table order - self.join_tuples_a_left_join.clear(); + self.join_tuples_a.clear(); self.join_tuples_b.clear(); let mut hashes = std::mem::take(&mut self.hashes); - let rows = self.set_join_series(context, chunk)?; + let rows = self + .row_values + .get_values(context, chunk, self.join_nulls)?; hash_rows(&rows, &mut hashes, &self.hb); if self.join_nulls || rows.null_count() == 0 { @@ -291,13 +208,12 @@ impl GenericJoinProbe { .data ._take_unchecked_slice_sorted(&self.join_tuples_b, false, IsSorted::Ascending) }; - let right_df = - unsafe { right_df._take_opt_chunked_unchecked_seq(&self.join_tuples_a_left_join) }; + let right_df = unsafe { right_df._take_opt_chunked_unchecked_seq(&self.join_tuples_a) }; let out = self.finish_join(left_df, right_df)?; - // clear memory - self.join_columns.clear(); + // Clear memory. + self.row_values.clear(); self.hashes.clear(); Ok(OperatorResult::Finished(chunk.with_data(out))) @@ -309,18 +225,17 @@ impl GenericJoinProbe { { for (i, (h, row)) in iter { let df_idx_right = i as IdxSize; - // get the hashtable belonging by this hash partition - let partition = hash_to_partition(*h, self.hash_tables.len()); - let current_table = unsafe { self.hash_tables.get_unchecked_release(partition) }; - let entry = current_table - .raw_entry() + let entry = self + .hash_tables + .raw_entry(*h) .from_hash(*h, |key| { compare_fn(key, *h, &self.materialized_join_cols, row) }) .map(|key_val| key_val.1); if let Some(indexes_left) = entry { + let indexes_left = &indexes_left.0; self.join_tuples_a.extend_from_slice(indexes_left); self.join_tuples_b .extend(std::iter::repeat(df_idx_right).take(indexes_left.len())); @@ -336,7 +251,9 @@ impl GenericJoinProbe { self.join_tuples_a.clear(); self.join_tuples_b.clear(); let mut hashes = std::mem::take(&mut self.hashes); - let rows = self.set_join_series(context, chunk)?; + let rows = self + .row_values + .get_values(context, chunk, self.join_nulls)?; hash_rows(&rows, &mut hashes, &self.hb); if self.join_nulls || rows.null_count() == 0 { @@ -358,7 +275,7 @@ impl GenericJoinProbe { }; let right_df = unsafe { let mut df = Cow::Borrowed(&chunk.data); - if let Some(ids) = &self.join_column_idx { + if let Some(ids) = &self.row_values.join_column_idx { let mut tmp = df.into_owned(); let cols = tmp.get_columns_mut(); // we go from higher idx to lower so that lower indices remain untouched @@ -378,15 +295,15 @@ impl GenericJoinProbe { }; let out = self.finish_join(a, b)?; - // clear memory - self.join_columns.clear(); + // Clear memory. + self.row_values.clear(); self.hashes.clear(); Ok(OperatorResult::Finished(chunk.with_data(out))) } } -impl Operator for GenericJoinProbe { +impl Operator for GenericJoinProbe { fn execute( &mut self, context: &PExecutionContext, diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs new file mode 100644 index 0000000000000..77db52b9f42c8 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_probe_outer.rs @@ -0,0 +1,315 @@ +use std::sync::atomic::Ordering; + +use arrow::array::{Array, BinaryArray, MutablePrimitiveArray}; +use polars_core::export::ahash::RandomState; +use polars_core::prelude::*; +use polars_core::series::IsSorted; +use polars_ops::chunked_array::DfTake; +use polars_ops::frame::join::_finish_join; +use polars_ops::prelude::_coalesce_outer_join; +use smartstring::alias::String as SmartString; + +use crate::executors::sinks::joins::generic_build::*; +use crate::executors::sinks::joins::row_values::RowValues; +use crate::executors::sinks::joins::PartitionedMap; +use crate::executors::sinks::utils::hash_rows; +use crate::executors::sinks::ExtraPayload; +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, Operator, OperatorResult, PExecutionContext}; + +#[derive(Clone)] +pub struct GenericOuterJoinProbe { + /// all chunks are stacked into a single dataframe + /// the dataframe is not rechunked. + df_a: Arc, + // Dummy needed for the flush phase. + df_b_dummy: Option, + /// The join columns are all tightly packed + /// the values of a join column(s) can be found + /// by: + /// first get the offset of the chunks and multiply that with the number of join + /// columns + /// * chunk_offset = (idx * n_join_keys) + /// * end = (offset + n_join_keys) + materialized_join_cols: Arc<[BinaryArray]>, + suffix: Arc, + hb: RandomState, + /// partitioned tables that will be used for probing. + /// stores the key and the chunk_idx, df_idx of the left table. + hash_tables: Arc>, + + // amortize allocations + // in inner join these are the left table + // in left join there are the right table + join_tuples_a: Vec, + // in inner join these are the right table + // in left join there are the left table + join_tuples_b: MutablePrimitiveArray, + hashes: Vec, + // the join order is swapped to ensure we hash the smaller table + swapped: bool, + // cached output names + output_names: Option>, + join_nulls: bool, + coalesce: bool, + thread_no: usize, + row_values: RowValues, + key_names_left: Arc<[SmartString]>, + key_names_right: Arc<[SmartString]>, +} + +impl GenericOuterJoinProbe { + #[allow(clippy::too_many_arguments)] + pub(super) fn new( + df_a: DataFrame, + materialized_join_cols: Arc<[BinaryArray]>, + suffix: Arc, + hb: RandomState, + hash_tables: Arc>, + join_columns_right: Arc>>, + swapped: bool, + // Re-use the hashes allocation of the build side. + amortized_hashes: Vec, + join_nulls: bool, + coalesce: bool, + key_names_left: Arc<[SmartString]>, + key_names_right: Arc<[SmartString]>, + ) -> Self { + GenericOuterJoinProbe { + df_a: Arc::new(df_a), + df_b_dummy: None, + materialized_join_cols, + suffix, + hb, + hash_tables, + join_tuples_a: vec![], + join_tuples_b: MutablePrimitiveArray::new(), + hashes: amortized_hashes, + swapped, + output_names: None, + join_nulls, + coalesce, + thread_no: 0, + row_values: RowValues::new(join_columns_right, false), + key_names_left, + key_names_right, + } + } + + fn finish_join(&mut self, left_df: DataFrame, right_df: DataFrame) -> PolarsResult { + fn inner( + left_df: DataFrame, + right_df: DataFrame, + suffix: &str, + swapped: bool, + output_names: &mut Option>, + ) -> PolarsResult { + let (mut left_df, right_df) = if swapped { + (right_df, left_df) + } else { + (left_df, right_df) + }; + Ok(match output_names { + None => { + let out = _finish_join(left_df, right_df, Some(suffix))?; + *output_names = Some(out.get_column_names_owned()); + out + }, + Some(names) => unsafe { + // SAFETY: + // if we have duplicate names, we overwrite + // them in the next snippet + left_df + .get_columns_mut() + .extend_from_slice(right_df.get_columns()); + left_df + .get_columns_mut() + .iter_mut() + .zip(names) + .for_each(|(s, name)| { + s.rename(name); + }); + left_df + }, + }) + } + + if self.coalesce { + let out = inner( + left_df.clone(), + right_df, + self.suffix.as_ref(), + self.swapped, + &mut self.output_names, + )?; + let l = self + .key_names_left + .iter() + .map(|s| s.as_str()) + .collect::>(); + let r = self + .key_names_right + .iter() + .map(|s| s.as_str()) + .collect::>(); + Ok(_coalesce_outer_join( + out, + &l, + &r, + Some(self.suffix.as_ref()), + &left_df, + )) + } else { + inner( + left_df.clone(), + right_df, + self.suffix.as_ref(), + self.swapped, + &mut self.output_names, + ) + } + } + + fn match_outer<'b, I>(&mut self, iter: I) + where + I: Iterator + 'b, + { + for (i, (h, row)) in iter { + let df_idx_right = i as IdxSize; + + let entry = self + .hash_tables + .raw_entry(*h) + .from_hash(*h, |key| { + compare_fn(key, *h, &self.materialized_join_cols, row) + }) + .map(|key_val| key_val.1); + + if let Some((indexes_left, tracker)) = entry { + // compiles to normal store: https://rust.godbolt.org/z/331hMo339 + tracker.get_tracker().store(true, Ordering::Relaxed); + + self.join_tuples_a.extend_from_slice(indexes_left); + self.join_tuples_b + .extend_constant(indexes_left.len(), Some(df_idx_right)); + } else { + self.join_tuples_a.push(ChunkId::null()); + self.join_tuples_b.push_value(df_idx_right); + } + } + } + + fn execute_outer( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + self.join_tuples_a.clear(); + self.join_tuples_b.clear(); + + if self.df_b_dummy.is_none() { + self.df_b_dummy = Some(chunk.data.clear()) + } + + let mut hashes = std::mem::take(&mut self.hashes); + let rows = self + .row_values + .get_values(context, chunk, self.join_nulls)?; + hash_rows(&rows, &mut hashes, &self.hb); + + if self.join_nulls || rows.null_count() == 0 { + let iter = hashes.iter().zip(rows.values_iter()).enumerate(); + self.match_outer(iter); + } else { + let iter = hashes + .iter() + .zip(rows.iter()) + .enumerate() + .filter_map(|(i, (h, row))| row.map(|row| (i, (h, row)))); + self.match_outer(iter); + } + self.hashes = hashes; + + let left_df = unsafe { + self.df_a + ._take_opt_chunked_unchecked_seq(&self.join_tuples_a) + }; + let right_df = unsafe { + self.join_tuples_b.with_freeze(|idx| { + let idx = IdxCa::from(idx.clone()); + let out = chunk.data.take_unchecked_impl(&idx, false); + // Drop so that the freeze context can go back to mutable array. + drop(idx); + out + }) + }; + let out = self.finish_join(left_df, right_df)?; + Ok(OperatorResult::Finished(chunk.with_data(out))) + } + + fn execute_flush(&mut self) -> PolarsResult { + let ht = self.hash_tables.inner(); + let n = ht.len(); + self.join_tuples_a.clear(); + + ht.iter().enumerate().for_each(|(i, ht)| { + if i % n == self.thread_no { + ht.iter().for_each(|(_k, (idx_left, tracker))| { + let found_match = tracker.get_tracker().load(Ordering::Relaxed); + + if !found_match { + self.join_tuples_a.extend_from_slice(idx_left); + } + }) + } + }); + + let left_df = unsafe { + self.df_a + ._take_chunked_unchecked_seq(&self.join_tuples_a, IsSorted::Not) + }; + + let size = left_df.height(); + let right_df = self.df_b_dummy.as_ref().unwrap(); + + let right_df = unsafe { + DataFrame::new_no_checks( + right_df + .get_columns() + .iter() + .map(|s| Series::full_null(s.name(), size, s.dtype())) + .collect(), + ) + }; + + let out = self.finish_join(left_df, right_df)?; + Ok(OperatorResult::Finished(DataChunk::new(0, out))) + } +} + +impl Operator for GenericOuterJoinProbe { + fn execute( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + ) -> PolarsResult { + self.execute_outer(context, chunk) + } + + fn flush(&mut self) -> PolarsResult { + self.execute_flush() + } + + fn must_flush(&self) -> bool { + true + } + + fn split(&self, thread_no: usize) -> Box { + let mut new = self.clone(); + new.thread_no = thread_no; + Box::new(new) + } + fn fmt(&self) -> &str { + "generic_outer_join_probe" + } +} diff --git a/crates/polars-pipe/src/executors/sinks/joins/mod.rs b/crates/polars-pipe/src/executors/sinks/joins/mod.rs index 5da9cfd715c29..53ba6e896b715 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/mod.rs @@ -2,8 +2,100 @@ mod cross; mod generic_build; mod generic_probe_inner_left; +mod generic_probe_outer; +mod row_values; + +use std::hash::{BuildHasherDefault, Hash, Hasher}; +use std::sync::atomic::AtomicBool; #[cfg(feature = "cross_join")] pub(crate) use cross::*; pub(crate) use generic_build::GenericBuild; +use polars_core::hashing::IdHasher; +use polars_core::prelude::IdxSize; use polars_ops::prelude::JoinType; +use polars_utils::idx_vec::UnitVec; +use polars_utils::index::ChunkId; +use polars_utils::partitioned::PartitionedHashMap; + +trait ToRow { + fn get_row(&self) -> &[u8]; +} + +impl ToRow for &[u8] { + #[inline(always)] + fn get_row(&self) -> &[u8] { + self + } +} + +impl ToRow for Option<&[u8]> { + #[inline(always)] + fn get_row(&self) -> &[u8] { + self.unwrap() + } +} + +// This is the hash and the Index offset in the chunks and the index offset in the dataframe +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub(super) struct Key { + pub(super) hash: u64, + /// We use the MSB as tracker for outer join matches + /// So the 25th bit of the chunk_idx will be used for that. + idx: ChunkId, +} + +impl Key { + #[inline] + fn new(hash: u64, chunk_idx: IdxSize, df_idx: IdxSize) -> Self { + let idx = ChunkId::store(chunk_idx, df_idx); + Key { hash, idx } + } +} + +impl Hash for Key { + #[inline] + fn hash(&self, state: &mut H) { + state.write_u64(self.hash) + } +} + +pub(crate) trait ExtraPayload: Clone + Sync + Send + Default + 'static { + /// Tracker used in the outer join. + fn get_tracker(&self) -> &AtomicBool { + panic!() + } +} +impl ExtraPayload for () {} + +#[repr(transparent)] +pub(crate) struct Tracker { + inner: AtomicBool, +} + +impl Default for Tracker { + #[inline] + fn default() -> Self { + Self { + inner: AtomicBool::new(false), + } + } +} + +// Needed for the trait resolving. We should never hit this. +impl Clone for Tracker { + fn clone(&self) -> Self { + panic!() + } +} + +impl ExtraPayload for Tracker { + #[inline(always)] + fn get_tracker(&self) -> &AtomicBool { + &self.inner + } +} + +type PartitionedMap = + PartitionedHashMap, V), BuildHasherDefault>; diff --git a/crates/polars-pipe/src/executors/sinks/joins/row_values.rs b/crates/polars-pipe/src/executors/sinks/joins/row_values.rs new file mode 100644 index 0000000000000..b144e98cf87d4 --- /dev/null +++ b/crates/polars-pipe/src/executors/sinks/joins/row_values.rs @@ -0,0 +1,91 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, BinaryArray, StaticArray}; +use arrow::compute::utils::combine_validities_and; +use polars_core::error::PolarsResult; +use polars_row::RowsEncoded; + +use crate::expressions::PhysicalPipedExpr; +use crate::operators::{DataChunk, PExecutionContext}; + +#[derive(Clone)] +pub(super) struct RowValues { + current_rows: RowsEncoded, + join_column_eval: Arc>>, + join_columns_material: Vec, + // Location of join columns. + // These column locations need to be dropped from the rhs + pub join_column_idx: Option>, + det_join_idx: bool, +} + +impl RowValues { + pub(super) fn new( + join_column_eval: Arc>>, + det_join_idx: bool, + ) -> Self { + Self { + current_rows: Default::default(), + join_column_eval, + join_column_idx: None, + join_columns_material: vec![], + det_join_idx, + } + } + + pub(super) fn clear(&mut self) { + self.join_columns_material.clear(); + } + + pub(super) fn get_values( + &mut self, + context: &PExecutionContext, + chunk: &DataChunk, + join_nulls: bool, + ) -> PolarsResult> { + // Memory should already be cleared on previous iteration. + debug_assert!(self.join_columns_material.is_empty()); + let determine_idx = self.det_join_idx && self.join_column_idx.is_none(); + let mut names = vec![]; + + for phys_e in self.join_column_eval.iter() { + let s = phys_e.evaluate(chunk, context.execution_state.as_any())?; + let s = s.to_physical_repr().rechunk(); + if determine_idx { + names.push(s.name().to_string()); + } + self.join_columns_material.push(s.array_ref(0).clone()); + } + + // We determine the indices of the columns that have to be removed + // if swapped the join column is already removed from the `build_df` as that will + // be the rhs one. + if determine_idx { + let mut idx = names + .iter() + .filter_map(|name| chunk.data.get_column_index(name)) + .collect::>(); + // Ensure that it is sorted so that we can later remove columns in + // a predictable order + idx.sort_unstable(); + self.join_column_idx = Some(idx); + } + polars_row::convert_columns_amortized_no_order( + &self.join_columns_material, + &mut self.current_rows, + ); + + // SAFETY: we keep rows-encode alive + let array = unsafe { self.current_rows.borrow_array() }; + Ok(if join_nulls { + array + } else { + let validity = self + .join_columns_material + .iter() + .map(|arr| arr.validity().cloned()) + .fold(None, |l, r| combine_validities_and(l.as_ref(), r.as_ref())); + array.with_validity_typed(validity) + }) + } +} diff --git a/crates/polars-pipe/src/executors/sinks/memory.rs b/crates/polars-pipe/src/executors/sinks/memory.rs index 43aad94e02c48..a5540889ddb7d 100644 --- a/crates/polars-pipe/src/executors/sinks/memory.rs +++ b/crates/polars-pipe/src/executors/sinks/memory.rs @@ -41,7 +41,7 @@ impl MemTracker { } /// This shouldn't be called often as this is expensive. - fn refresh_memory(&self) { + pub fn refresh_memory(&self) { self.available_mem .store(MEMINFO.free() as usize, Ordering::Relaxed); } @@ -57,6 +57,12 @@ impl MemTracker { self.available_mem.load(Ordering::Relaxed) } + pub(super) fn get_available_latest(&self) -> usize { + self.refresh_memory(); + self.fetch_count.store(0, Ordering::Relaxed); + self.available_mem.load(Ordering::Relaxed) + } + pub(super) fn free_memory_fraction_since_start(&self) -> f64 { // We divide first to reduce the precision loss in floats. // We also add 1.0 to available_at_start to prevent division by zero. diff --git a/crates/polars-pipe/src/executors/sinks/mod.rs b/crates/polars-pipe/src/executors/sinks/mod.rs index d783e5728fed7..86d46eb3082cd 100644 --- a/crates/polars-pipe/src/executors/sinks/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/mod.rs @@ -9,6 +9,8 @@ mod slice; mod sort; mod utils; +use std::sync::OnceLock; + pub(crate) use joins::*; pub(crate) use ordered::*; #[cfg(any( @@ -26,15 +28,16 @@ pub(crate) use sort::*; // Overallocation seems a lot more expensive than resizing so we start reasonable small. const HASHMAP_INIT_SIZE: usize = 64; -pub(crate) static POLARS_TEMP_DIR: &str = "POLARS_TEMP_DIR"; - -pub(crate) fn get_base_temp_dir() -> String { - let base_dir = std::env::var(POLARS_TEMP_DIR) - .unwrap_or_else(|_| std::env::temp_dir().to_string_lossy().into_owned()); +pub(crate) static POLARS_TEMP_DIR: OnceLock = OnceLock::new(); - if polars_core::config::verbose() { - eprintln!("Temporary directory path in use: {}", base_dir); - } +pub(crate) fn get_base_temp_dir() -> &'static str { + POLARS_TEMP_DIR.get_or_init(|| { + let tmp = std::env::var("POLARS_TEMP_DIR") + .unwrap_or_else(|_| std::env::temp_dir().to_string_lossy().into_owned()); - base_dir + if polars_core::config::verbose() { + eprintln!("Temporary directory path in use: {}", &tmp); + } + tmp + }) } diff --git a/crates/polars-pipe/src/executors/sinks/output/ipc.rs b/crates/polars-pipe/src/executors/sinks/output/ipc.rs index f7cbab92248a2..e0a479f329666 100644 --- a/crates/polars-pipe/src/executors/sinks/output/ipc.rs +++ b/crates/polars-pipe/src/executors/sinks/output/ipc.rs @@ -1,5 +1,4 @@ use std::path::Path; -use std::sync::Arc; use crossbeam_channel::bounded; use polars_core::prelude::*; diff --git a/crates/polars-pipe/src/executors/sinks/output/parquet.rs b/crates/polars-pipe/src/executors/sinks/output/parquet.rs index c1c79a44e5776..49a83d0227f81 100644 --- a/crates/polars-pipe/src/executors/sinks/output/parquet.rs +++ b/crates/polars-pipe/src/executors/sinks/output/parquet.rs @@ -1,21 +1,64 @@ +use std::any::Any; use std::path::Path; +use std::thread::JoinHandle; -use crossbeam_channel::bounded; +use crossbeam_channel::{bounded, Receiver, Sender}; use polars_core::prelude::*; -use polars_io::parquet::ParquetWriter; +use polars_io::parquet::{BatchedWriter, ParquetWriter, RowGroupIter}; use polars_plan::prelude::ParquetWriteOptions; use crate::executors::sinks::output::file_sink::{init_writer_thread, FilesSink, SinkWriter}; +use crate::operators::{DataChunk, FinalizedSink, PExecutionContext, Sink, SinkResult}; use crate::pipeline::morsels_per_sink; -pub struct ParquetSink {} +type RowGroups = Vec>; + +pub(super) fn init_row_group_writer_thread( + receiver: Receiver>, + writer: Arc>, + // this is used to determine when a batch of chunks should be written to disk + // all chunks per push should be collected to determine in which order they should + // be written + morsels_per_sink: usize, +) -> JoinHandle<()> { + std::thread::spawn(move || { + // keep chunks around until all chunks per sink are written + // then we write them all at once. + let mut batched = Vec::with_capacity(morsels_per_sink); + while let Ok(rgs) = receiver.recv() { + // `last_write` indicates if all chunks are processed, e.g. this is the last write. + // this is when `write_chunks` is called with `None`. + let last_write = if let Some(rgs) = rgs { + batched.push(rgs); + false + } else { + true + }; + + if batched.len() == morsels_per_sink || last_write { + batched.sort_by_key(|chunk| chunk.0); + + for (_, rg) in batched.drain(0..) { + writer.write_row_groups(rg).unwrap() + } + } + if last_write { + writer.finish().unwrap(); + return; + } + } + }) +} + +#[derive(Clone)] +pub struct ParquetSink { + writer: Arc>, + io_thread_handle: Arc>>, + sender: Sender>, +} impl ParquetSink { #[allow(clippy::new_ret_no_self)] - pub fn new( - path: &Path, - options: ParquetWriteOptions, - schema: &Schema, - ) -> PolarsResult { + pub fn new(path: &Path, options: ParquetWriteOptions, schema: &Schema) -> PolarsResult { let file = std::fs::File::create(path)?; let writer = ParquetWriter::new(file) .with_compression(options.compression) @@ -27,26 +70,74 @@ impl ParquetSink { .set_parallel(false) .batched(schema)?; - let writer = Box::new(writer) as Box; - + let writer = Arc::new(writer); let morsels_per_sink = morsels_per_sink(); - let backpressure = morsels_per_sink * 2; + + let backpressure = morsels_per_sink * 4; let (sender, receiver) = bounded(backpressure); - let io_thread_handle = Arc::new(Some(init_writer_thread( + let io_thread_handle = Arc::new(Some(init_row_group_writer_thread( receiver, - writer, - options.maintain_order, + writer.clone(), morsels_per_sink, ))); - Ok(FilesSink { - sender, + Ok(Self { + writer, io_thread_handle, + sender, }) } } +impl Sink for ParquetSink { + fn sink(&mut self, _context: &PExecutionContext, chunk: DataChunk) -> PolarsResult { + // Encode and compress row-groups on every thread. + let row_groups = self + .writer + .encode_and_compress(&chunk.data) + .collect::>>()?; + // Only then send the compressed pages to the writer. + self.sender + .send(Some((chunk.chunk_index, row_groups))) + .unwrap(); + Ok(SinkResult::CanHaveMoreInput) + } + + fn combine(&mut self, _other: &mut dyn Sink) { + // Nothing to do + } + + fn split(&self, _thread_no: usize) -> Box { + Box::new(self.clone()) + } + + fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { + // `None` indicates that we can flush all remaining chunks. + self.sender.send(None).unwrap(); + + // wait until all files written + // some unwrap/mut kung-fu to get a hold of `self` + Arc::get_mut(&mut self.io_thread_handle) + .unwrap() + .take() + .unwrap() + .join() + .unwrap(); + + // return a dummy dataframe; + Ok(FinalizedSink::Finished(Default::default())) + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } + + fn fmt(&self) -> &str { + "parquet_sink" + } +} + #[cfg(feature = "cloud")] pub struct ParquetCloudSink {} #[cfg(feature = "cloud")] @@ -79,7 +170,7 @@ impl ParquetCloudSink { let io_thread_handle = Arc::new(Some(init_writer_thread( receiver, writer, - parquet_options.maintain_order, + true, morsels_per_sink, ))); diff --git a/crates/polars-pipe/src/executors/sinks/reproject.rs b/crates/polars-pipe/src/executors/sinks/reproject.rs index aee32b9cad487..8d66e102fd92d 100644 --- a/crates/polars-pipe/src/executors/sinks/reproject.rs +++ b/crates/polars-pipe/src/executors/sinks/reproject.rs @@ -2,7 +2,6 @@ use std::any::Any; use polars_core::schema::SchemaRef; -use crate::executors::operators::ReProjectOperator; use crate::executors::sources::ReProjectSource; use crate::operators::{ DataChunk, FinalizedSink, PExecutionContext, PolarsResult, Sink, SinkResult, @@ -43,12 +42,10 @@ impl Sink for ReProjectSink { FinalizedSink::Finished(df) => { FinalizedSink::Finished(df.select(self.schema.iter_names())?) }, - FinalizedSink::Operator(op) => { - FinalizedSink::Operator(Box::new(ReProjectOperator::new(self.schema.clone(), op))) - }, FinalizedSink::Source(source) => { FinalizedSink::Source(Box::new(ReProjectSource::new(self.schema.clone(), source))) }, + _ => unimplemented!(), }) } diff --git a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs index 60547ec6c076c..64acfa30a5db0 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/ooc.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/ooc.rs @@ -1,5 +1,6 @@ use std::path::Path; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::time::Instant; use crossbeam_queue::SegQueue; use polars_core::prelude::*; @@ -14,6 +15,7 @@ use polars_ops::prelude::*; use rayon::prelude::*; use crate::executors::sinks::io::{DfIter, IOThread}; +use crate::executors::sinks::memory::MemTracker; use crate::executors::sinks::sort::source::SortSource; use crate::operators::FinalizedSink; @@ -27,7 +29,6 @@ pub(super) fn read_df(path: &Path) -> PolarsResult { // and amortize IO cost #[derive(Default)] struct PartitionSpillBuf { - row_count: AtomicU32, // keep track of the length // that's cheaper than iterating the linked list len: AtomicU32, @@ -36,20 +37,15 @@ struct PartitionSpillBuf { } impl PartitionSpillBuf { - fn push(&self, df: DataFrame) -> Option { + fn push(&self, df: DataFrame, spill_limit: u64) -> Option { debug_assert!(df.height() > 0); - let acc = self - .row_count - .fetch_add(df.height() as u32, Ordering::Relaxed); let size = self .size .fetch_add(df.estimated_size() as u64, Ordering::Relaxed); - let larger_than_32_mb = size > 1 << 25; let len = self.len.fetch_add(1, Ordering::Relaxed); self.chunks.push(df); - if acc > 50_000 || larger_than_32_mb { - // reset all statistics - self.row_count.store(0, Ordering::Relaxed); + if size > spill_limit { + // Reset all statistics. self.len.store(0, Ordering::Relaxed); self.size.store(0, Ordering::Relaxed); // other threads can be pushing while we drain @@ -63,36 +59,55 @@ impl PartitionSpillBuf { } } - fn finish(self) -> Option { + fn finish(&self) -> Option { if !self.chunks.is_empty() { - let iter = self.chunks.into_iter(); - Some(accumulate_dataframes_vertical_unchecked(iter)) + let len = self.len.load(Ordering::Relaxed) + 1; + let mut out = Vec::with_capacity(len as usize); + while let Some(df) = self.chunks.pop() { + out.push(df) + } + Some(accumulate_dataframes_vertical_unchecked(out)) } else { None } } } -struct PartitionSpiller { +pub(crate) struct PartitionSpiller { partitions: Vec, + // Spill limit in bytes. + spill_limit: u64, } impl PartitionSpiller { - fn new(n_parts: usize) -> Self { + fn new(n_parts: usize, spill_limit: u64) -> Self { let mut partitions = vec![]; partitions.resize_with(n_parts + 1, PartitionSpillBuf::default); - Self { partitions } + Self { + partitions, + spill_limit, + } } fn push(&self, partition: usize, df: DataFrame) -> Option { - self.partitions[partition].push(df) + self.partitions[partition].push(df, self.spill_limit) + } + + pub(crate) fn get(&self, partition: usize) -> Option { + self.partitions[partition].finish() + } + + pub(crate) fn len(&self) -> usize { + self.partitions.len() } - fn spill_all(self, io_thread: &IOThread) { + #[cfg(debug_assertions)] + // Used in testing only. + fn spill_all(&self, io_thread: &IOThread) { let min_len = std::cmp::max(self.partitions.len() / POOL.current_num_threads(), 2); POOL.install(|| { self.partitions - .into_par_iter() + .par_iter() .with_min_len(min_len) .enumerate() .for_each(|(part, part_buf)| { @@ -100,21 +115,35 @@ impl PartitionSpiller { io_thread.dump_partition_local(part as IdxSize, df) } }) - }) + }); + eprintln!("PARTITIONED FORCE SPILLED") } } +#[allow(clippy::too_many_arguments)] pub(super) fn sort_ooc( - io_thread: &IOThread, + io_thread: IOThread, // these partitions are the samples // these are not yet assigned to a buckets samples: Series, idx: usize, descending: bool, + nulls_last: bool, slice: Option<(i64, usize)>, verbose: bool, + memtrack: MemTracker, + ooc_start: Instant, ) -> PolarsResult { + let now = Instant::now(); + let multithreaded_partition = std::env::var("POLARS_OOC_SORT_PAR_PARTITION").is_ok(); + let spill_size = std::env::var("POLARS_OOC_SORT_SPILL_SIZE") + .map(|v| v.parse::().expect("integer")) + .unwrap_or(1 << 26); let samples = samples.to_physical_repr().into_owned(); + let spill_size = std::cmp::min( + memtrack.get_available_latest() / (samples.len() * 3), + spill_size, + ); // we collect as I am not sure that if we write to the same directory the // iterator will read those also. @@ -123,10 +152,11 @@ pub(super) fn sort_ooc( let files = std::fs::read_dir(dir)?.collect::>>()?; if verbose { + eprintln!("spill size: {} mb", spill_size / 1024 / 1024); eprintln!("processing {} files", files.len()); } - let partitions_spiller = PartitionSpiller::new(samples.len()); + let partitions_spiller = PartitionSpiller::new(samples.len(), spill_size as u64); POOL.install(|| { files.par_iter().try_for_each(|entry| { @@ -141,18 +171,27 @@ pub(super) fn sort_ooc( let assigned_parts = det_partitions(sort_col, &samples, descending); // partition the dataframe into proper buckets - let (iter, unique_assigned_parts) = partition_df(df, &assigned_parts)?; + let (iter, unique_assigned_parts) = + partition_df(df, &assigned_parts, multithreaded_partition)?; for (part, df) in unique_assigned_parts.into_no_null_iter().zip(iter) { if let Some(df) = partitions_spiller.push(part as usize, df) { io_thread.dump_partition_local(part, df) } } + io_thread.clean(path); PolarsResult::Ok(()) }) })?; - partitions_spiller.spill_all(io_thread); if verbose { - eprintln!("finished partitioning sort files"); + eprintln!("partitioning sort took: {:?}", now.elapsed()); + } + + // Branch for testing so we hit different parts in the Source phase. + #[cfg(debug_assertions)] + { + if std::env::var("POLARS_SPILL_SORT_PARTITIONS").is_ok() { + partitions_spiller.spill_all(&io_thread) + } } let files = std::fs::read_dir(dir)? @@ -172,7 +211,18 @@ pub(super) fn sort_ooc( }) .collect::>>()?; - let source = SortSource::new(files, idx, descending, slice, verbose); + let source = SortSource::new( + files, + idx, + descending, + nulls_last, + slice, + verbose, + io_thread, + memtrack, + ooc_start, + partitions_spiller, + ); Ok(FinalizedSink::Source(Box::new(source))) } @@ -182,8 +232,12 @@ fn det_partitions(s: &Series, partitions: &Series, descending: bool) -> IdxCa { search_sorted(partitions, &s, SearchSortedSide::Any, descending).unwrap() } -fn partition_df(df: DataFrame, partitions: &IdxCa) -> PolarsResult<(DfIter, IdxCa)> { - let groups = partitions.group_tuples(false, false)?; +fn partition_df( + df: DataFrame, + partitions: &IdxCa, + multithreaded: bool, +) -> PolarsResult<(DfIter, IdxCa)> { + let groups = partitions.group_tuples(multithreaded, false)?; let partitions = unsafe { partitions.clone().into_series().agg_first(&groups) }; let partitions = partitions.idx().unwrap().clone(); @@ -191,7 +245,9 @@ fn partition_df(df: DataFrame, partitions: &IdxCa) -> PolarsResult<(DfIter, IdxC GroupsProxy::Idx(idx) => { let iter = idx.into_iter().map(move |(_, group)| { // groups are in bounds and sorted - unsafe { df._take_unchecked_slice_sorted(&group, false, IsSorted::Ascending) } + unsafe { + df._take_unchecked_slice_sorted(&group, multithreaded, IsSorted::Ascending) + } }); Box::new(iter) as DfIter }, diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index b054282a32575..0f66db59b5a3a 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -1,5 +1,6 @@ use std::any::Any; use std::sync::{Arc, RwLock}; +use std::time::Instant; use polars_core::config::verbose; use polars_core::error::PolarsResult; @@ -35,6 +36,8 @@ pub struct SortSink { current_chunk_rows: usize, // total bytes of tables in current chunks current_chunks_size: usize, + // Start time of OOC phase. + ooc_start: Option, } impl SortSink { @@ -54,9 +57,12 @@ impl SortSink { dist_sample: vec![], current_chunk_rows: 0, current_chunks_size: 0, + ooc_start: None, }; if ooc { - eprintln!("OOC sort forced"); + if verbose() { + eprintln!("OOC sort forced"); + } out.init_ooc().unwrap(); } out @@ -66,6 +72,7 @@ impl SortSink { if verbose() { eprintln!("OOC sort started"); } + self.ooc_start = Some(Instant::now()); self.ooc = true; // start IO thread @@ -99,10 +106,8 @@ impl SortSink { } fn dump(&mut self, force: bool) -> PolarsResult<()> { - let larger_than_32_mb = self.current_chunks_size > 1 << 25; - if (force || larger_than_32_mb || self.current_chunk_rows > 50_000) - && !self.chunks.is_empty() - { + let larger_than_32_mb = self.current_chunks_size > (1 << 25); + if (force || larger_than_32_mb) && !self.chunks.is_empty() { // into a single chunk because multiple file IO's is expensive // and may lead to many smaller files in ooc-sort later, which is exponentially // expensive @@ -163,6 +168,7 @@ impl Sink for SortSink { dist_sample: vec![], current_chunk_rows: 0, current_chunks_size: 0, + ooc_start: self.ooc_start, }) } @@ -170,8 +176,8 @@ impl Sink for SortSink { if self.ooc { // spill everything self.dump(true).unwrap(); - let lock = self.io_thread.read().unwrap(); - let io_thread = lock.as_ref().unwrap(); + let mut lock = self.io_thread.write().unwrap(); + let io_thread = lock.take().unwrap(); let dist = Series::from_any_values("", &self.dist_sample, true).unwrap(); let dist = dist.sort_with(SortOptions { @@ -181,15 +187,25 @@ impl Sink for SortSink { maintain_order: self.sort_args.maintain_order, }); - block_thread_until_io_thread_done(io_thread); + let instant = self.ooc_start.unwrap(); + if context.verbose { + eprintln!("finished sinking into OOC sort in {:?}", instant.elapsed()); + } + block_thread_until_io_thread_done(&io_thread); + if context.verbose { + eprintln!("full file dump of OOC sort took {:?}", instant.elapsed()); + } sort_ooc( io_thread, dist, self.sort_idx, self.sort_args.descending[0], + self.sort_args.nulls_last, self.sort_args.slice, context.verbose, + self.mem_track.clone(), + instant, ) } else { let chunks = std::mem::take(&mut self.chunks); @@ -199,6 +215,7 @@ impl Sink for SortSink { self.sort_idx, self.sort_args.descending[0], self.sort_args.slice, + self.sort_args.nulls_last, )?; Ok(FinalizedSink::Finished(df)) } @@ -218,6 +235,7 @@ pub(super) fn sort_accumulated( sort_idx: usize, descending: bool, slice: Option<(i64, usize)>, + nulls_last: bool, ) -> PolarsResult { // This is needed because we can have empty blocks and we require chunks to have single chunks. df.as_single_chunk_par(); @@ -225,7 +243,7 @@ pub(super) fn sort_accumulated( df.sort_impl( vec![sort_column], vec![descending], - false, + nulls_last, false, slice, true, diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index 764b12e9e82fd..6fdede156fcf6 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,6 +1,6 @@ use std::any::Any; -use arrow::array::{ArrayRef, BinaryArray}; +use arrow::array::BinaryArray; use polars_core::prelude::sort::_broadcast_descending; use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; @@ -300,7 +300,7 @@ impl Sink for SortSinkMultiple { output_schema: self.output_schema.clone(), }))), // SortSink should not produce this branch - FinalizedSink::Operator(_) => unreachable!(), + FinalizedSink::Operator => unreachable!(), } } diff --git a/crates/polars-pipe/src/executors/sinks/sort/source.rs b/crates/polars-pipe/src/executors/sinks/sort/source.rs index 4d40efc8a5bb2..edc49130c6367 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/source.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/source.rs @@ -1,32 +1,51 @@ +use std::iter::Peekable; use std::path::PathBuf; +use std::time::Instant; use polars_core::prelude::*; use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df}; use polars_core::POOL; use rayon::prelude::*; -use crate::executors::sinks::sort::ooc::read_df; +use crate::executors::sinks::io::IOThread; +use crate::executors::sinks::memory::MemTracker; +use crate::executors::sinks::sort::ooc::{read_df, PartitionSpiller}; use crate::executors::sinks::sort::sink::sort_accumulated; use crate::executors::sources::get_source_index; use crate::operators::{DataChunk, PExecutionContext, Source, SourceResult}; pub struct SortSource { - files: std::vec::IntoIter<(u32, PathBuf)>, + files: Peekable>, n_threads: usize, sort_idx: usize, descending: bool, + nulls_last: bool, chunk_offset: IdxSize, slice: Option<(i64, usize)>, finished: bool, + io_thread: IOThread, + memtrack: MemTracker, + // Start of the Source phase + source_start: Instant, + // Start of the OOC sort operation. + ooc_start: Instant, + partition_spiller: PartitionSpiller, + current_part: usize, } impl SortSource { + #[allow(clippy::too_many_arguments)] pub(super) fn new( mut files: Vec<(u32, PathBuf)>, sort_idx: usize, descending: bool, + nulls_last: bool, slice: Option<(i64, usize)>, verbose: bool, + io_thread: IOThread, + memtrack: MemTracker, + ooc_start: Instant, + partition_spiller: PartitionSpiller, ) -> Self { if verbose { eprintln!("started sort source phase"); @@ -35,16 +54,23 @@ impl SortSource { files.sort_unstable_by_key(|entry| entry.0); let n_threads = POOL.current_num_threads(); - let files = files.into_iter(); + let files = files.into_iter().peekable(); Self { files, n_threads, sort_idx, descending, + nulls_last, chunk_offset: get_source_index(1) as IdxSize, slice, finished: false, + io_thread, + memtrack, + source_start: Instant::now(), + ooc_start, + partition_spiller, + current_part: 0, } } fn finish_batch(&mut self, dfs: Vec) -> Vec { @@ -59,57 +85,138 @@ impl SortSource { }) .collect() } + + fn finish_from_df(&mut self, df: DataFrame) -> PolarsResult { + // Sort a single partition + // We always need to sort again! + let current_slice = self.slice; + + let mut df = match &mut self.slice { + None => sort_accumulated(df, self.sort_idx, self.descending, None, self.nulls_last), + Some((offset, len)) => { + let df_len = df.height(); + debug_assert!(*offset >= 0); + let out = if *offset as usize >= df_len { + *offset -= df_len as i64; + Ok(df.slice(0, 0)) + } else { + let out = sort_accumulated( + df, + self.sort_idx, + self.descending, + current_slice, + self.nulls_last, + ); + *len = len.saturating_sub(df_len); + *offset = 0; + out + }; + if *len == 0 { + self.finished = true; + } + out + }, + }?; + + // convert to chunks + let dfs = split_df(&mut df, self.n_threads)?; + Ok(SourceResult::GotMoreData(self.finish_batch(dfs))) + } + fn print_verbose(&self, verbose: bool) { + if verbose { + eprintln!("sort source phase took: {:?}", self.source_start.elapsed()); + eprintln!("full ooc sort took: {:?}", self.ooc_start.elapsed()); + } + } + + fn get_from_memory( + &mut self, + read: &mut Vec, + read_size: &mut usize, + part: usize, + keep_track: bool, + ) { + while self.current_part <= part { + if let Some(df) = self.partition_spiller.get(self.current_part - 1) { + if keep_track { + *read_size += df.estimated_size(); + } + read.push(df); + } + self.current_part += 1; + } + } } impl Source for SortSource { - fn get_batches(&mut self, _context: &PExecutionContext) -> PolarsResult { + fn get_batches(&mut self, context: &PExecutionContext) -> PolarsResult { // early return - if self.finished { + if self.finished || self.current_part >= self.partition_spiller.len() { + self.print_verbose(context.verbose); return Ok(SourceResult::Finished); } + self.current_part += 1; + let mut read_size = 0; + let mut read = vec![]; match self.files.next() { - None => Ok(SourceResult::Finished), - Some((_, path)) => { - let files = std::fs::read_dir(path)?.collect::>>()?; - - // read the files in a single partition in parallel - let dfs = POOL.install(|| { - files - .par_iter() - .map(|entry| read_df(&entry.path())) - .collect::>>() - })?; - let df = accumulate_dataframes_vertical_unchecked(dfs); - // sort a single partition - // We always need to sort again! - // We cannot trust - let current_slice = self.slice; - let mut df = match &mut self.slice { - None => sort_accumulated(df, self.sort_idx, self.descending, None), - Some((offset, len)) => { - let df_len = df.height(); - assert!(*offset >= 0); - let out = if *offset as usize >= df_len { - *offset -= df_len as i64; - Ok(df.slice(0, 0)) - } else { - let out = - sort_accumulated(df, self.sort_idx, self.descending, current_slice); - *len = len.saturating_sub(df_len); - *offset = 0; - out - }; - if *len == 0 { - self.finished = true; - } - out - }, - }?; - - // convert to chunks - let dfs = split_df(&mut df, self.n_threads)?; - Ok(SourceResult::GotMoreData(self.finish_batch(dfs))) + None => { + // Ensure we fetch all from memory. + self.get_from_memory( + &mut read, + &mut read_size, + self.partition_spiller.len(), + false, + ); + if read.is_empty() { + self.print_verbose(context.verbose); + Ok(SourceResult::Finished) + } else { + self.finished = true; + let df = accumulate_dataframes_vertical_unchecked(read); + self.finish_from_df(df) + } + }, + Some((mut partition, mut path)) => { + self.get_from_memory(&mut read, &mut read_size, partition as usize, true); + let limit = self.memtrack.get_available() / 3; + + loop { + if let Some(in_mem) = self.partition_spiller.get(partition as usize) { + read_size += in_mem.estimated_size(); + read.push(in_mem) + } + + let files = std::fs::read_dir(&path)?.collect::>>()?; + + // read the files in a single partition in parallel + let dfs = POOL.install(|| { + files + .par_iter() + .map(|entry| { + let df = read_df(&entry.path())?; + Ok(df) + }) + .collect::>>() + })?; + + let df = accumulate_dataframes_vertical_unchecked(dfs); + read_size += df.estimated_size(); + read.push(df); + if read_size > limit { + break; + } + + let Some((next_part, next_path)) = self.files.next() else { + break; + }; + path = next_path; + partition = next_part; + } + let df = accumulate_dataframes_vertical_unchecked(read); + let out = self.finish_from_df(df); + self.io_thread.clean(path); + out }, } } diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs index d12791137ca01..85e19f4bfc718 100644 --- a/crates/polars-pipe/src/executors/sources/parquet.rs +++ b/crates/polars-pipe/src/executors/sources/parquet.rs @@ -4,7 +4,7 @@ use std::path::PathBuf; use std::sync::Arc; use arrow::datatypes::ArrowSchemaRef; -use polars_core::config::get_file_prefetch_size; +use polars_core::config::{env_force_async, get_file_prefetch_size}; use polars_core::error::*; use polars_core::prelude::Series; use polars_core::POOL; @@ -204,8 +204,7 @@ impl ParquetSource { if verbose { eprintln!("POLARS PREFETCH_SIZE: {}", prefetch_size) } - let run_async = paths.first().map(is_cloud_url).unwrap_or(false) - || std::env::var("POLARS_FORCE_ASYNC").as_deref().unwrap_or("") == "1"; + let run_async = paths.first().map(is_cloud_url).unwrap_or(false) || env_force_async(); let mut source = ParquetSource { batched_readers: VecDeque::new(), @@ -230,10 +229,8 @@ impl ParquetSource { } Ok(source) } -} -impl Source for ParquetSource { - fn get_batches(&mut self, _context: &PExecutionContext) -> PolarsResult { + fn prefetch_files(&mut self) -> PolarsResult<()> { // We already start downloading the next file, we can only do that if we don't have a limit. // In the case of a limit we first must update the row count with the batch results. // @@ -269,6 +266,13 @@ impl Source for ParquetSource { } } } + Ok(()) + } +} + +impl Source for ParquetSource { + fn get_batches(&mut self, _context: &PExecutionContext) -> PolarsResult { + self.prefetch_files()?; let Some(mut reader) = self.batched_readers.pop_front() else { // If there was no new reader, we depleted all of them and are finished. diff --git a/crates/polars-pipe/src/lib.rs b/crates/polars-pipe/src/lib.rs index b2724e9a8981d..bd18b177e4f2f 100644 --- a/crates/polars-pipe/src/lib.rs +++ b/crates/polars-pipe/src/lib.rs @@ -1,5 +1,3 @@ -extern crate core; - mod executors; pub mod expressions; pub mod operators; diff --git a/crates/polars-pipe/src/operators/chunks.rs b/crates/polars-pipe/src/operators/chunks.rs index 55a4975970b8f..6fb289b73b50c 100644 --- a/crates/polars-pipe/src/operators/chunks.rs +++ b/crates/polars-pipe/src/operators/chunks.rs @@ -44,7 +44,7 @@ pub(crate) fn chunks_to_df_unchecked(chunks: Vec) -> DataFrame { /// /// Changing the `DataFrame` into contiguous chunks is the caller's /// responsibility. -#[cfg(feature = "parquet")] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] #[derive(Clone)] pub(crate) struct StreamingVstacker { current_dataframe: Option, @@ -52,7 +52,7 @@ pub(crate) struct StreamingVstacker { output_chunk_size: usize, } -#[cfg(feature = "parquet")] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] impl StreamingVstacker { /// Create a new instance. pub fn new(output_chunk_size: usize) -> Self { @@ -103,7 +103,7 @@ impl StreamingVstacker { } } -#[cfg(feature = "parquet")] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] impl Default for StreamingVstacker { /// 4 MB was chosen based on some empirical experiments that showed it to /// be decently faster than lower or higher values, and it's small enough @@ -114,7 +114,7 @@ impl Default for StreamingVstacker { } #[cfg(test)] -#[cfg(feature = "parquet")] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] mod test { use super::*; diff --git a/crates/polars-pipe/src/operators/sink.rs b/crates/polars-pipe/src/operators/sink.rs index 19b60aef1772c..a1933c346b5f0 100644 --- a/crates/polars-pipe/src/operators/sink.rs +++ b/crates/polars-pipe/src/operators/sink.rs @@ -1,4 +1,7 @@ use std::any::Any; +use std::fmt::{Debug, Formatter}; + +use polars_utils::arena::Node; use super::*; @@ -10,10 +13,21 @@ pub enum SinkResult { pub enum FinalizedSink { Finished(DataFrame), - Operator(Box), + Operator, Source(Box), } +impl Debug for FinalizedSink { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let s = match self { + FinalizedSink::Finished(_) => "finished", + FinalizedSink::Operator => "operator", + FinalizedSink::Source(_) => "source", + }; + write!(f, "{s}") + } +} + pub trait Sink: Send + Sync { fn sink(&mut self, context: &PExecutionContext, chunk: DataChunk) -> PolarsResult; @@ -26,4 +40,13 @@ pub trait Sink: Send + Sync { fn as_any(&mut self) -> &mut dyn Any; fn fmt(&self) -> &str; + + fn is_join_build(&self) -> bool { + false + } + + // Only implemented for Join sinks + fn node(&self) -> Node { + unimplemented!() + } } diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index f0b83a60fdd7c..0b841b504027c 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::rc::Rc; -use std::sync::Arc; use hashbrown::hash_map::Entry; use polars_core::prelude::*; @@ -10,16 +9,18 @@ use polars_io::predicates::{PhysicalIoExpr, StatsEvaluator}; use polars_ops::prelude::JoinType; use polars_plan::prelude::*; -use crate::executors::operators::HstackOperator; +use crate::executors::operators::{HstackOperator, PlaceHolder}; use crate::executors::sinks::group_by::aggregates::convert_to_hash_agg; use crate::executors::sinks::group_by::GenericGroupby2; use crate::executors::sinks::*; use crate::executors::{operators, sources}; use crate::expressions::PhysicalPipedExpr; use crate::operators::{Operator, Sink as SinkTrait, Source}; -use crate::pipeline::dispatcher::SinkNode; +use crate::pipeline::dispatcher::ThreadedSink; use crate::pipeline::PipeLine; +pub type CallBacks = PlHashMap; + fn exprs_to_physical( exprs: &[Node], expr_arena: &Arena, @@ -160,6 +161,7 @@ pub fn get_sink( lp_arena: &Arena, expr_arena: &mut Arena, to_physical: &F, + callbacks: &mut CallBacks, ) -> PolarsResult> where F: Fn(Node, &Arena, Option<&SchemaRef>) -> PolarsResult>, @@ -204,29 +206,29 @@ where }, #[cfg(feature = "cloud")] SinkType::Cloud { + #[cfg(any(feature = "parquet", feature = "ipc"))] uri, file_type, + #[cfg(any(feature = "parquet", feature = "ipc"))] cloud_options, + .. } => { - let uri = uri.as_ref().as_str(); - let input_schema = lp_arena.get(*input).schema(lp_arena); - let cloud_options = &cloud_options; match &file_type { #[cfg(feature = "parquet")] FileType::Parquet(parquet_options) => Box::new(ParquetCloudSink::new( - uri, + uri.as_ref().as_str(), cloud_options.as_ref(), *parquet_options, - input_schema.as_ref(), + lp_arena.get(*input).schema(lp_arena).as_ref(), )?) as Box, #[cfg(feature = "ipc")] FileType::Ipc(ipc_options) => Box::new(IpcCloudSink::new( - uri, - cloud_options.as_ref(), - *ipc_options, - input_schema.as_ref(), - )?) + uri.as_ref().as_str(), + cloud_options.as_ref(), + *ipc_options, + lp_arena.get(*input).schema(lp_arena).as_ref(), + )?) as Box, #[allow(unreachable_patterns)] other_file_type => todo!("Cloud-sinking of the file type {other_file_type:?} is not (yet) supported."), @@ -245,12 +247,17 @@ where // slice pushdown optimization should not set this one in a streaming query. assert!(options.args.slice.is_none()); let swapped = swap_join_order(options); + let placeholder = callbacks.get(&node).unwrap().clone(); match &options.args.how { #[cfg(feature = "cross_join")] - JoinType::Cross => Box::new(CrossJoin::new(options.args.suffix().into(), swapped)) - as Box, - join_type @ JoinType::Inner | join_type @ JoinType::Left => { + JoinType::Cross => Box::new(CrossJoin::new( + options.args.suffix().into(), + swapped, + node, + placeholder, + )) as Box, + jt => { let input_schema_left = lp_arena.get(*input_left).schema(lp_arena); let join_columns_left = Arc::new(exprs_to_physical( left_on, @@ -266,22 +273,61 @@ where Some(input_schema_right.as_ref()), )?); - let (join_columns_left, join_columns_right) = if swapped { - (join_columns_right, join_columns_left) - } else { - (join_columns_left, join_columns_right) + let swap_eval = || { + if swapped { + (join_columns_right.clone(), join_columns_left.clone()) + } else { + (join_columns_left.clone(), join_columns_right.clone()) + } }; - Box::new(GenericBuild::new( - Arc::from(options.args.suffix()), - join_type.clone(), - swapped, - join_columns_left, - join_columns_right, - options.args.join_nulls, - )) as Box + match jt { + join_type @ JoinType::Inner | join_type @ JoinType::Left => { + let (join_columns_left, join_columns_right) = swap_eval(); + + Box::new(GenericBuild::<()>::new( + Arc::from(options.args.suffix()), + join_type.clone(), + swapped, + join_columns_left, + join_columns_right, + options.args.join_nulls, + node, + // We don't need the key names for these joins. + vec![].into(), + vec![].into(), + placeholder, + )) as Box + }, + JoinType::Outer { .. } => { + // First get the names before we (potentially) swap. + let key_names_left = join_columns_left + .iter() + .map(|e| e.field(&input_schema_left).unwrap().name) + .collect(); + let key_names_right = join_columns_left + .iter() + .map(|e| e.field(&input_schema_left).unwrap().name) + .collect(); + // Swap. + let (join_columns_left, join_columns_right) = swap_eval(); + + Box::new(GenericBuild::::new( + Arc::from(options.args.suffix()), + jt.clone(), + swapped, + join_columns_left, + join_columns_right, + options.args.join_nulls, + node, + key_names_left, + key_names_right, + placeholder, + )) as Box + }, + _ => unimplemented!(), + } }, - _ => unimplemented!(), } }, Slice { input, offset, len } => { @@ -482,8 +528,8 @@ where Ok(out) } -pub fn get_dummy_operator() -> Box { - Box::new(operators::PlaceHolder {}) +pub fn get_dummy_operator() -> PlaceHolder { + operators::PlaceHolder::new() } fn get_hstack( @@ -606,13 +652,15 @@ where pub fn create_pipeline( sources: &[Node], operators: Vec>, - operator_nodes: Vec, sink_nodes: Vec<(usize, Node, Rc>)>, lp_arena: &Arena, expr_arena: &mut Arena, to_physical: F, verbose: bool, + // Shared sinks are stored in a cache, so that they share state. + // If the shared sink is already in cache, that one is used. sink_cache: &mut PlHashMap>, + callbacks: &mut CallBacks, ) -> PolarsResult where F: Fn(Node, &Arena, Option<&SchemaRef>) -> PolarsResult>, @@ -676,32 +724,29 @@ where // ensure that shared sinks are really shared // to achieve this we store/fetch them in a cache let sink = if *shared_count.borrow() == 1 { - get_sink(node, lp_arena, expr_arena, &to_physical)? + get_sink(node, lp_arena, expr_arena, &to_physical, callbacks)? } else { match sink_cache.entry(node.0) { Entry::Vacant(entry) => { - let sink = get_sink(node, lp_arena, expr_arena, &to_physical)?; + let sink = get_sink(node, lp_arena, expr_arena, &to_physical, callbacks)?; entry.insert(sink.split(0)); sink }, Entry::Occupied(entry) => entry.get().split(0), } }; - Ok(SinkNode::new( + Ok(ThreadedSink::new( sink, shared_count, offset + operator_offset, - node, )) }) .collect::>>()?; Ok(PipeLine::new( source_objects, - operator_objects, - operator_nodes, + unsafe { std::mem::transmute(operator_objects) }, sinks, - operator_offset, verbose, )) } diff --git a/crates/polars-pipe/src/pipeline/dispatcher.rs b/crates/polars-pipe/src/pipeline/dispatcher.rs deleted file mode 100644 index bde8db52806b2..0000000000000 --- a/crates/polars-pipe/src/pipeline/dispatcher.rs +++ /dev/null @@ -1,628 +0,0 @@ -use std::cell::RefCell; -use std::collections::{BTreeSet, VecDeque}; -use std::fmt::{Debug, Formatter}; -use std::rc::Rc; -use std::sync::{Arc, Mutex}; - -use polars_core::error::PolarsResult; -use polars_core::frame::DataFrame; -use polars_core::utils::accumulate_dataframes_vertical_unchecked; -use polars_core::POOL; -use polars_utils::arena::Node; -use polars_utils::sync::SyncPtr; -use rayon::prelude::*; - -use crate::executors::sources::DataFrameSource; -use crate::operators::{ - DataChunk, FinalizedSink, Operator, OperatorResult, PExecutionContext, SExecutionContext, Sink, - SinkResult, Source, SourceResult, -}; -use crate::pipeline::morsels_per_sink; - -pub(super) struct SinkNode { - pub sinks: Vec>, - /// when that hits 0, the sink will finalize - pub shared_count: Rc>, - initial_shared_count: u32, - /// - offset in the operators vec - /// at that point the sink should be called. - /// the pipeline will first call the operators on that point and then - /// push the result in the sink. - pub operator_end: usize, - pub node: Node, -} - -impl SinkNode { - pub fn new( - sink: Box, - shared_count: Rc>, - operator_end: usize, - node: Node, - ) -> Self { - let n_threads = morsels_per_sink(); - let sinks = (0..n_threads).map(|i| sink.split(i)).collect(); - let initial_shared_count = *shared_count.borrow(); - SinkNode { - sinks, - initial_shared_count, - shared_count, - operator_end, - node, - } - } - - // Only the first node of a shared sink should recurse. The others should return. - fn allow_recursion(&self) -> bool { - self.initial_shared_count == *self.shared_count.borrow() - } -} - -/// A pipeline consists of: -/// -/// - 1. One or more sources. -/// Sources get pulled and their data is pushed into operators. -/// - 2. Zero or more operators. -/// The operators simply pass through data, modifying it as they need. -/// Operators can work on batches and don't need all data in scope to -/// succeed. -/// Think for example on multiply a few columns, or applying a predicate. -/// Operators can shrink the batches: filter -/// Grow the batches: explode/ melt -/// Keep them the same size: element-wise operations -/// The probe side of join operations is also an operator. -/// -/// -/// - 3. One or more sinks -/// A sink needs all data in scope to finalize a pipeline branch. -/// Think of sorts, preparing a build phase of a join, group_by + aggregations. -/// -/// This struct will have the SOS (source, operators, sinks) of its own pipeline branch, but also -/// the SOS of other branches. The SOS are stored data oriented and the sinks have an offset that -/// indicates the last operator node before that specific sink. We only store the `end offset` and -/// keep track of the starting operator during execution. -/// -/// Pipelines branches are shared with other pipeline branches at the join/union nodes. -/// # JOIN -/// Consider this tree: -/// out -/// / -/// /\ -/// 1 2 -/// -/// And let's consider that branch 2 runs first. It will run until the join node where it will sink -/// into a build table. Once that is done it will replace the build-phase placeholder operator in -/// branch 1. Branch one can then run completely until out. -pub struct PipeLine { - /// All the sources of this pipeline - sources: Vec>, - /// All the operators of this pipeline. Some may be placeholders that will be replaced during - /// execution - operators: Vec>>, - /// The nodes of operators. These are used to identify operators between pipelines - operator_nodes: Vec, - /// - offset in the operators vec - /// at that point the sink should be called. - /// the pipeline will first call the operators on that point and then - /// push the result in the sink. - /// - shared_count - /// when that hits 0, the sink will finalize - /// - node of the sink - sinks: Vec, - /// are used to identify the sink shared with other pipeline branches - sink_nodes: Vec, - /// Other branch of the pipeline/tree that must be executed - /// after this one has executed. - /// the dispatcher takes care of this. - other_branches: Rc>>, - /// this is a correction as there may be more `operators` than nodes - /// as during construction, source may have inserted operators - operator_offset: usize, - /// Log runtime info to stderr - verbose: bool, -} - -impl PipeLine { - #[allow(clippy::type_complexity)] - pub(super) fn new( - sources: Vec>, - operators: Vec>, - operator_nodes: Vec, - sinks: Vec, - operator_offset: usize, - verbose: bool, - ) -> PipeLine { - debug_assert_eq!(operators.len(), operator_nodes.len() + operator_offset); - // we don't use the power of two partition size here - // we only do that in the sinks itself. - let n_threads = morsels_per_sink(); - - let sink_nodes = sinks.iter().map(|s| s.node).collect(); - // We split so that every thread gets an operator - // every index maps to a chain of operators than can be pushed as a pipeline for one thread - let operators = (0..n_threads) - .map(|i| operators.iter().map(|op| op.split(i)).collect()) - .collect(); - - PipeLine { - sources, - operators, - operator_nodes, - sinks, - sink_nodes, - other_branches: Default::default(), - operator_offset, - verbose, - } - } - - /// Create a pipeline only consisting of a single branch that always finishes with a sink - pub fn new_simple( - sources: Vec>, - operators: Vec>, - sink: Box, - verbose: bool, - ) -> Self { - let operators_len = operators.len(); - Self::new( - sources, - operators, - vec![], - vec![SinkNode::new( - sink, - Rc::new(RefCell::new(1)), - operators_len, - Node::default(), - )], - 0, - verbose, - ) - } - - /// Add a parent - /// This should be in the right order - pub fn with_other_branch(self, rhs: PipeLine) -> Self { - self.other_branches.borrow_mut().push_back(rhs); - self - } - - // returns if operator was successfully replaced - fn replace_operator(&mut self, op: &dyn Operator, node: Node) -> bool { - if let Some(pos) = self.operator_nodes.iter().position(|n| *n == node) { - let pos = pos + self.operator_offset; - for (i, operator_pipe) in &mut self.operators.iter_mut().enumerate() { - operator_pipe[pos] = op.split(i) - } - true - } else { - false - } - } - - /// Take data chunks from the sources and pushes them into the operators + sink. Every operator - /// works thread local. - /// The caller passes an `operator_start`/`operator_end` to indicate which part of the pipeline - /// branch should be executed. - fn par_process_chunks( - &mut self, - chunks: Vec, - sink: &mut [Box], - ec: &PExecutionContext, - operator_start: usize, - operator_end: usize, - src: &mut Box, - ) -> PolarsResult<(Option, SourceResult)> { - debug_assert!(chunks.len() <= sink.len()); - - fn run_operator_pipe( - pipe: &PipeLine, - operator_start: usize, - operator_end: usize, - chunk: DataChunk, - sink: &mut Box, - operator_pipe: &mut [Box], - ec: &PExecutionContext, - ) -> PolarsResult { - // truncate the operators that should run into the current sink. - let operator_pipe = &mut operator_pipe[operator_start..operator_end]; - - if operator_pipe.is_empty() { - sink.sink(ec, chunk) - } else { - pipe.push_operators(chunk, ec, operator_pipe, sink) - } - } - let sink_results = Arc::new(Mutex::new(None)); - let mut next_batches: Option> = None; - let next_batches_ptr = &mut next_batches as *mut Option>; - let next_batches_ptr = unsafe { SyncPtr::new(next_batches_ptr) }; - - // 1. We will iterate the chunks/sinks/operators - // where every iteration belongs to a single thread - // 2. Then we will truncate the pipeline by `start`/`end` - // so that the pipeline represents pipeline that belongs to this sink - // 3. Then we push the data - // # Threading - // Within a rayon scope - // we spawn the jobs. They don't have to finish in any specific order, - // this makes it more lightweight than `par_iter` - - // temporarily take to please the borrow checker - let mut operators = std::mem::take(&mut self.operators); - - // borrow as ref and move into the closure - let pipeline = &*self; - POOL.scope(|s| { - for ((chunk, sink), operator_pipe) in chunks - .into_iter() - .zip(sink.iter_mut()) - .zip(operators.iter_mut()) - { - let sink_results = sink_results.clone(); - s.spawn(move |_| { - let out = run_operator_pipe( - pipeline, - operator_start, - operator_end, - chunk, - sink, - operator_pipe, - ec, - ); - match out { - Ok(SinkResult::Finished) | Err(_) => { - let mut lock = sink_results.lock().unwrap(); - *lock = Some(out) - }, - _ => {}, - } - }) - } - // already get batches on the thread pool - // if one job is finished earlier we can already start that work - s.spawn(|_| { - let out = src.get_batches(ec); - unsafe { - let ptr = next_batches_ptr.get(); - *ptr = Some(out); - } - }) - }); - self.operators = operators; - - let next_batches = next_batches.unwrap()?; - let mut lock = sink_results.lock().unwrap(); - lock.take() - .transpose() - .map(|sink_result| (sink_result, next_batches)) - } - - /// This thread local logic that pushed a data chunk into the operators + sink - /// It can be that a single operator needs to be called multiple times, this is for instance the - /// case with joins that produce many tuples, that's why we keep a stack of `in_process` - /// operators. - fn push_operators( - &self, - chunk: DataChunk, - ec: &PExecutionContext, - operators: &mut [Box], - sink: &mut Box, - ) -> PolarsResult { - debug_assert!(!operators.is_empty()); - - // Stack based operator execution. - let mut in_process = vec![]; - let operator_offset = 0usize; - in_process.push((operator_offset, chunk)); - let mut needs_flush = BTreeSet::new(); - - while let Some((op_i, chunk)) = in_process.pop() { - match operators.get_mut(op_i) { - None => { - if let SinkResult::Finished = sink.sink(ec, chunk)? { - return Ok(SinkResult::Finished); - } - }, - Some(op) => { - match op.execute(ec, &chunk)? { - OperatorResult::Finished(chunk) => { - if op.must_flush() { - let _ = needs_flush.insert(op_i); - } - in_process.push((op_i + 1, chunk)) - }, - OperatorResult::HaveMoreOutPut(output_chunk) => { - // Push the next operator call with the same chunk on the stack - in_process.push((op_i, chunk)); - - // But first push the output in the next operator - // If a join can produce many rows, we want the filter to - // be executed in between, or sink into a slice so that we get - // sink::finished before we grow the stack with ever more coming chunks - in_process.push((op_i + 1, output_chunk)); - }, - OperatorResult::NeedsNewData => { - // done, take another chunk from the stack - }, - } - }, - } - } - - // Stack based flushing + operator execution. - if !needs_flush.is_empty() { - drop(in_process); - let mut in_process = vec![]; - - for op_i in needs_flush.into_iter() { - // Push all operators that need flushing on the stack. - // The `None` indicates that we have no `chunk` input, so we `flush`. - // `Some(chunk)` is the pushing branch - in_process.push((op_i, None)); - - // Next we immediately pop and determine the order of execution below. - // This is to ensure that all operators below upper operators are completely - // flushed when the `flush` is called in higher operators. As operators can `flush` - // multiple times. - while let Some((op_i, chunk)) = in_process.pop() { - match chunk { - // The branch for flushing. - None => { - let op = operators.get_mut(op_i).unwrap(); - match op.flush()? { - OperatorResult::Finished(chunk) => { - // Push the chunk in the next operator. - in_process.push((op_i + 1, Some(chunk))) - }, - OperatorResult::HaveMoreOutPut(chunk) => { - // Ensure it is flushed again - in_process.push((op_i, None)); - // Push the chunk in the next operator. - in_process.push((op_i + 1, Some(chunk))) - }, - _ => unreachable!(), - } - }, - // The branch for pushing data in the operators. - // This is the same as the default stack exectuor, except now it pushes - // `Some(chunk)` instead of `chunk`. - Some(chunk) => { - match operators.get_mut(op_i) { - None => { - if let SinkResult::Finished = sink.sink(ec, chunk)? { - return Ok(SinkResult::Finished); - } - }, - Some(op) => { - match op.execute(ec, &chunk)? { - OperatorResult::Finished(chunk) => { - in_process.push((op_i + 1, Some(chunk))) - }, - OperatorResult::HaveMoreOutPut(output_chunk) => { - // Push the next operator call with the same chunk on the stack - in_process.push((op_i, Some(chunk))); - - // But first push the output in the next operator - // If a join can produce many rows, we want the filter to - // be executed in between, or sink into a slice so that we get - // sink::finished before we grow the stack with ever more coming chunks - in_process.push((op_i + 1, Some(output_chunk))); - }, - OperatorResult::NeedsNewData => { - // Done, take another chunk from the stack - }, - } - }, - } - }, - } - } - } - } - - Ok(SinkResult::CanHaveMoreInput) - } - - /// Replace the current sources with a [`DataFrameSource`]. - fn set_df_as_sources(&mut self, df: DataFrame) { - let src = Box::new(DataFrameSource::from_df(df)) as Box; - self.set_sources(src) - } - - /// Replace the current sources. - fn set_sources(&mut self, src: Box) { - self.sources.clear(); - self.sources.push(src); - } - - fn run_pipeline_no_finalize( - &mut self, - ec: &PExecutionContext, - pipeline_q: Rc>>, - ) -> PolarsResult<(u32, Box)> { - let mut out = None; - let mut operator_start = 0; - let last_i = self.sinks.len() - 1; - - // for unions we typically first want to push all pipelines - // into the union sink before we call `finalize` - // however if the sink is finished early, (for instance a `head`) - // we don't want to run the rest of the pipelines and we finalize early - let mut sink_finished = false; - - for (i, mut sink) in std::mem::take(&mut self.sinks).into_iter().enumerate() { - for src in &mut std::mem::take(&mut self.sources) { - let mut next_batches = src.get_batches(ec)?; - - while let SourceResult::GotMoreData(chunks) = next_batches { - // Every batches iteration we check if we must continue. - ec.execution_state.should_stop()?; - - let (sink_result, next_batches2) = self.par_process_chunks( - chunks, - &mut sink.sinks, - ec, - operator_start, - sink.operator_end, - src, - )?; - next_batches = next_batches2; - - if let Some(SinkResult::Finished) = sink_result { - sink_finished = true; - break; - } - } - } - - // Before we reduce we also check if we should continue. - ec.execution_state.should_stop()?; - let allow_recursion = sink.allow_recursion(); - - // The sinks have taken all chunks thread locally, now we reduce them into a single - // result sink. - let mut reduced_sink = POOL - .install(|| { - sink.sinks.into_par_iter().reduce_with(|mut a, mut b| { - a.combine(&mut *b); - a - }) - }) - .unwrap(); - operator_start = sink.operator_end; - - let mut shared_sink_count = { - let mut shared_sink_count = sink.shared_count.borrow_mut(); - *shared_sink_count -= 1; - *shared_sink_count - }; - - // Prevent very deep recursion. Only the outer callee can pop and run. - if allow_recursion { - while shared_sink_count > 0 && !sink_finished { - let mut pipeline = pipeline_q.borrow_mut().pop_front().unwrap(); - let (count, mut sink) = - pipeline.run_pipeline_no_finalize(ec, pipeline_q.clone())?; - reduced_sink.combine(sink.as_mut()); - shared_sink_count = count; - } - } - - if i != last_i { - let sink_result = reduced_sink.finalize(ec)?; - match sink_result { - // turn this sink an a new source - FinalizedSink::Finished(df) => self.set_df_as_sources(df), - FinalizedSink::Source(src) => self.set_sources(src), - // should not happen - FinalizedSink::Operator(_) => { - unreachable!() - }, - } - } else { - out = Some((shared_sink_count, reduced_sink)) - } - } - Ok(out.unwrap()) - } - - /// Run a single pipeline branch. - /// This pulls data from the sources and pushes it into the operators which run on a different - /// thread and finalize in a sink. - /// - /// The sink can be finished, but can also become a new source and then rinse and repeat. - pub fn run_pipeline( - &mut self, - ec: &PExecutionContext, - pipeline_q: Rc>>, - ) -> PolarsResult> { - let (sink_shared_count, mut reduced_sink) = - self.run_pipeline_no_finalize(ec, pipeline_q)?; - assert_eq!(sink_shared_count, 0); - Ok(reduced_sink.finalize(ec).ok()) - } - - /// Executes all branches and replaces operators and sinks during execution to ensure - /// we materialize. - pub fn execute(&mut self, state: Box) -> PolarsResult { - let ec = PExecutionContext::new(state, self.verbose); - - if self.verbose { - eprintln!("{self:?}"); - eprintln!("{:?}", &self.other_branches); - } - let mut sink_out = self.run_pipeline(&ec, self.other_branches.clone())?; - let mut sink_nodes = std::mem::take(&mut self.sink_nodes); - loop { - match &mut sink_out { - None => { - let mut pipeline = self.other_branches.borrow_mut().pop_front().unwrap(); - sink_out = pipeline.run_pipeline(&ec, self.other_branches.clone())?; - sink_nodes = std::mem::take(&mut pipeline.sink_nodes); - }, - Some(FinalizedSink::Finished(df)) => return Ok(std::mem::take(df)), - Some(FinalizedSink::Source(src)) => return consume_source(&mut **src, &ec), - - // - // 1/\ - // 2/\ - // 3\ - // the left hand side of the join has finished and now is an operator - // we replace the dummy node in the right hand side pipeline with this - // operator and then we run the pipeline rinse and repeat - // until the final right hand side pipeline ran - Some(FinalizedSink::Operator(op)) => { - // we unwrap, because the latest pipeline should not return an Operator - let mut pipeline = self.other_branches.borrow_mut().pop_front().unwrap(); - - // latest sink_node will be the operator, as the left side of the join - // always finishes that branch. - if let Some(sink_node) = sink_nodes.pop() { - // we traverse all pipeline - pipeline.replace_operator(op.as_ref(), sink_node); - // if there are unions, there can be more - for pl in self.other_branches.borrow_mut().iter_mut() { - pl.replace_operator(op.as_ref(), sink_node); - } - } - sink_out = pipeline.run_pipeline(&ec, self.other_branches.clone())?; - sink_nodes = std::mem::take(&mut pipeline.sink_nodes); - }, - } - } - } -} - -impl Debug for PipeLine { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut fmt = String::new(); - let mut start = 0usize; - fmt.push_str(self.sources[0].fmt()); - for sink in &self.sinks { - fmt.push_str(" -> "); - // take operators of a single thread - let ops = &self.operators[0]; - // slice the pipeline - let ops = &ops[start..sink.operator_end]; - for op in ops { - fmt.push_str(op.fmt()); - fmt.push_str(" -> ") - } - start = sink.operator_end; - fmt.push_str(sink.sinks[0].fmt()) - } - write!(f, "{fmt}") - } -} - -/// Take a source and materialize it into a [`DataFrame`]. -fn consume_source(src: &mut dyn Source, context: &PExecutionContext) -> PolarsResult { - let mut frames = Vec::with_capacity(32); - - while let SourceResult::GotMoreData(batch) = src.get_batches(context)? { - frames.extend(batch.into_iter().map(|chunk| chunk.data)) - } - Ok(accumulate_dataframes_vertical_unchecked(frames)) -} - -unsafe impl Send for PipeLine {} -unsafe impl Sync for PipeLine {} diff --git a/crates/polars-pipe/src/pipeline/dispatcher/drive_operator.rs b/crates/polars-pipe/src/pipeline/dispatcher/drive_operator.rs new file mode 100644 index 0000000000000..96b263351a3d2 --- /dev/null +++ b/crates/polars-pipe/src/pipeline/dispatcher/drive_operator.rs @@ -0,0 +1,254 @@ +use super::*; +use crate::pipeline::*; + +/// Take data chunks from the sources and pushes them into the operators + sink. Every operator +/// works thread local. +/// The caller passes an `operator_start`/`operator_end` to indicate which part of the pipeline +/// branch should be executed. +#[allow(clippy::too_many_arguments)] +pub(super) fn par_process_chunks( + chunks: Vec, + sink: ThreadedSinkMut, + ec: &PExecutionContext, + operators: &mut [ThreadedOperator], + operator_start: usize, + operator_end: usize, + src: &mut Box, + must_flush: &AtomicBool, +) -> PolarsResult<(Option, SourceResult)> { + debug_assert!(chunks.len() <= sink.len()); + let sink_results = Arc::new(Mutex::new(None)); + let mut next_batches: Option> = None; + let next_batches_ptr = &mut next_batches as *mut Option>; + let next_batches_ptr = unsafe { SyncPtr::new(next_batches_ptr) }; + + // 1. We will iterate the chunks/sinks/operators + // where every iteration belongs to a single thread + // 2. Then we will truncate the pipeline by `start`/`end` + // so that the pipeline represents pipeline that belongs to this sink + // 3. Then we push the data + // # Threading + // Within a rayon scope + // we spawn the jobs. They don't have to finish in any specific order, + // this makes it more lightweight than `par_iter` + + // borrow as ref and move into the closure + POOL.scope(|s| { + for ((chunk, sink), operator_pipe) in chunks + .into_iter() + .zip(sink.iter_mut()) + .zip(operators.iter_mut()) + { + let sink_results = sink_results.clone(); + // Truncate the operators that should run into the current sink. + let operator_pipe = &mut operator_pipe[operator_start..operator_end]; + + s.spawn(move |_| { + let out = if operator_pipe.is_empty() { + sink.sink(ec, chunk) + } else { + push_operators_single_thread(chunk, ec, operator_pipe, sink, must_flush) + }; + + match out { + Ok(SinkResult::Finished) | Err(_) => { + let mut lock = sink_results.lock().unwrap(); + *lock = Some(out) + }, + _ => {}, + } + }) + } + // already get batches on the thread pool + // if one job is finished earlier we can already start that work + s.spawn(|_| { + let out = src.get_batches(ec); + unsafe { + let ptr = next_batches_ptr.get(); + *ptr = Some(out); + } + }) + }); + + let next_batches = next_batches.unwrap()?; + let mut lock = sink_results.lock().unwrap(); + lock.take() + .transpose() + .map(|sink_result| (sink_result, next_batches)) +} + +/// This thread local logic that pushed a data chunk into the operators + sink +/// It can be that a single operator needs to be called multiple times, this is for instance the +/// case with joins that produce many tuples, that's why we keep a stack of `in_process` +/// operators. +pub(super) fn push_operators_single_thread( + chunk: DataChunk, + ec: &PExecutionContext, + operators: ThreadedOperatorMut, + sink: &mut Box, + must_flush: &AtomicBool, +) -> PolarsResult { + debug_assert!(!operators.is_empty()); + + // Stack based operator execution. + let mut in_process = vec![]; + let operator_offset = 0usize; + in_process.push((operator_offset, chunk)); + + while let Some((op_i, chunk)) = in_process.pop() { + match operators.get_mut(op_i) { + None => { + if let SinkResult::Finished = sink.sink(ec, chunk)? { + return Ok(SinkResult::Finished); + } + }, + Some(op) => { + let op = op.get_mut(); + match op.execute(ec, &chunk)? { + OperatorResult::Finished(chunk) => { + must_flush.store(op.must_flush(), Ordering::Relaxed); + in_process.push((op_i + 1, chunk)) + }, + OperatorResult::HaveMoreOutPut(output_chunk) => { + // Push the next operator call with the same chunk on the stack + in_process.push((op_i, chunk)); + + // But first push the output in the next operator + // If a join can produce many rows, we want the filter to + // be executed in between, or sink into a slice so that we get + // sink::finished before we grow the stack with ever more coming chunks + in_process.push((op_i + 1, output_chunk)); + }, + OperatorResult::NeedsNewData => { + // done, take another chunk from the stack + }, + } + }, + } + } + + Ok(SinkResult::CanHaveMoreInput) +} + +/// Similar to `par_process_chunks`. +/// The caller passes an `operator_start`/`operator_end` to indicate which part of the pipeline +/// branch should be executed. +pub(super) fn par_flush( + sink: ThreadedSinkMut, + ec: &PExecutionContext, + operators: &mut [ThreadedOperator], + operator_start: usize, + operator_end: usize, +) { + // 1. We will iterate the chunks/sinks/operators + // where every iteration belongs to a single thread + // 2. Then we will truncate the pipeline by `start`/`end` + // so that the pipeline represents pipeline that belongs to this sink + // 3. Then we push the data + // # Threading + // Within a rayon scope + // we spawn the jobs. They don't have to finish in any specific order, + // this makes it more lightweight than `par_iter` + + // borrow as ref and move into the closure + POOL.scope(|s| { + for (sink, operator_pipe) in sink.iter_mut().zip(operators.iter_mut()) { + // Truncate the operators that should run into the current sink. + let operator_pipe = &mut operator_pipe[operator_start..operator_end]; + + s.spawn(move |_| { + flush_operators(ec, operator_pipe, sink).unwrap(); + }) + } + }); +} + +pub(super) fn flush_operators( + ec: &PExecutionContext, + operators: &mut [PhysOperator], + sink: &mut Box, +) -> PolarsResult { + let needs_flush = operators + .iter_mut() + .enumerate() + .filter_map(|(i, op)| { + if op.get_mut().must_flush() { + Some(i) + } else { + None + } + }) + .collect::>(); + + // Stack based flushing + operator execution. + if !needs_flush.is_empty() { + let mut in_process = vec![]; + + for op_i in needs_flush.into_iter() { + // Push all operators that need flushing on the stack. + // The `None` indicates that we have no `chunk` input, so we `flush`. + // `Some(chunk)` is the pushing branch + in_process.push((op_i, None)); + + // Next we immediately pop and determine the order of execution below. + // This is to ensure that all operators below upper operators are completely + // flushed when the `flush` is called in higher operators. As operators can `flush` + // multiple times. + while let Some((op_i, chunk)) = in_process.pop() { + match chunk { + // The branch for flushing. + None => { + let op = operators.get_mut(op_i).unwrap().get_mut(); + match op.flush()? { + OperatorResult::Finished(chunk) => { + // Push the chunk in the next operator. + in_process.push((op_i + 1, Some(chunk))) + }, + OperatorResult::HaveMoreOutPut(chunk) => { + // Ensure it is flushed again + in_process.push((op_i, None)); + // Push the chunk in the next operator. + in_process.push((op_i + 1, Some(chunk))) + }, + _ => unreachable!(), + } + }, + // The branch for pushing data in the operators. + // This is the same as the default stack exectuor, except now it pushes + // `Some(chunk)` instead of `chunk`. + Some(chunk) => { + match operators.get_mut(op_i) { + None => { + if let SinkResult::Finished = sink.sink(ec, chunk)? { + return Ok(SinkResult::Finished); + } + }, + Some(op) => { + let op = op.get_mut(); + match op.execute(ec, &chunk)? { + OperatorResult::Finished(chunk) => { + in_process.push((op_i + 1, Some(chunk))) + }, + OperatorResult::HaveMoreOutPut(output_chunk) => { + // Push the next operator call with the same chunk on the stack + in_process.push((op_i, Some(chunk))); + + // But first push the output in the next operator + // If a join can produce many rows, we want the filter to + // be executed in between, or sink into a slice so that we get + // sink::finished before we grow the stack with ever more coming chunks + in_process.push((op_i + 1, Some(output_chunk))); + }, + OperatorResult::NeedsNewData => { + // Done, take another chunk from the stack + }, + } + }, + } + }, + } + } + } + } + Ok(SinkResult::Finished) +} diff --git a/crates/polars-pipe/src/pipeline/dispatcher/mod.rs b/crates/polars-pipe/src/pipeline/dispatcher/mod.rs new file mode 100644 index 0000000000000..a60f1efb80641 --- /dev/null +++ b/crates/polars-pipe/src/pipeline/dispatcher/mod.rs @@ -0,0 +1,380 @@ +use std::cell::RefCell; +use std::fmt::{Debug, Formatter}; +use std::rc::Rc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; + +use polars_core::error::PolarsResult; +use polars_core::utils::accumulate_dataframes_vertical_unchecked; +use polars_core::POOL; +use polars_utils::sync::SyncPtr; +use rayon::prelude::*; + +use crate::executors::sources::DataFrameSource; +use crate::operators::{ + DataChunk, FinalizedSink, OperatorResult, PExecutionContext, SExecutionContext, Sink, + SinkResult, Source, SourceResult, +}; +use crate::pipeline::dispatcher::drive_operator::{par_flush, par_process_chunks}; +mod drive_operator; +use super::*; + +pub(super) struct ThreadedSink { + /// A sink split per thread. + pub sinks: Vec>, + /// when that hits 0, the sink will finalize + pub shared_count: Rc>, + initial_shared_count: u32, + /// - offset in the operators vec + /// at that point the sink should be called. + /// the pipeline will first call the operators on that point and then + /// push the result in the sink. + pub operator_end: usize, +} + +impl ThreadedSink { + pub fn new(sink: Box, shared_count: Rc>, operator_end: usize) -> Self { + let n_threads = morsels_per_sink(); + let sinks = (0..n_threads).map(|i| sink.split(i)).collect(); + let initial_shared_count = *shared_count.borrow(); + ThreadedSink { + sinks, + initial_shared_count, + shared_count, + operator_end, + } + } + + // Only the first node of a shared sink should recurse. The others should return. + fn allow_recursion(&self) -> bool { + self.initial_shared_count == *self.shared_count.borrow() + } +} + +/// A pipeline consists of: +/// +/// - 1. One or more sources. +/// Sources get pulled and their data is pushed into operators. +/// - 2. Zero or more operators. +/// The operators simply pass through data, modifying it as they need. +/// Operators can work on batches and don't need all data in scope to +/// succeed. +/// Think for example on multiply a few columns, or applying a predicate. +/// Operators can shrink the batches: filter +/// Grow the batches: explode/ melt +/// Keep them the same size: element-wise operations +/// The probe side of join operations is also an operator. +/// +/// +/// - 3. One or more sinks +/// A sink needs all data in scope to finalize a pipeline branch. +/// Think of sorts, preparing a build phase of a join, group_by + aggregations. +/// +/// This struct will have the SOS (source, operators, sinks) of its own pipeline branch, but also +/// the SOS of other branches. The SOS are stored data oriented and the sinks have an offset that +/// indicates the last operator node before that specific sink. We only store the `end offset` and +/// keep track of the starting operator during execution. +/// +/// Pipelines branches are shared with other pipeline branches at the join/union nodes. +/// # JOIN +/// Consider this tree: +/// out +/// / +/// /\ +/// 1 2 +/// +/// And let's consider that branch 2 runs first. It will run until the join node where it will sink +/// into a build table. Once that is done it will replace the build-phase placeholder operator in +/// branch 1. Branch one can then run completely until out. +pub struct PipeLine { + /// All the sources of this pipeline + sources: Vec>, + /// All the operators of this pipeline. Some may be placeholders that will be replaced during + /// execution + operators: Vec, + /// - offset in the operators vec + /// at that point the sink should be called. + /// the pipeline will first call the operators on that point and then + /// push the result in the sink. + /// - shared_count + /// when that hits 0, the sink will finalize + /// - node of the sink + sinks: Vec, + /// Log runtime info to stderr + verbose: bool, +} + +impl PipeLine { + #[allow(clippy::type_complexity)] + pub(super) fn new( + sources: Vec>, + operators: Vec, + sinks: Vec, + verbose: bool, + ) -> PipeLine { + // we don't use the power of two partition size here + // we only do that in the sinks itself. + let n_threads = morsels_per_sink(); + + // We split so that every thread gets an operator + // every index maps to a chain of operators than can be pushed as a pipeline for one thread + let operators = (0..n_threads) + .map(|i| { + operators + .iter() + .map(|op| op.get_ref().split(i).into()) + .collect() + }) + .collect(); + + PipeLine { + sources, + operators, + sinks, + verbose, + } + } + + /// Create a pipeline only consisting of a single branch that always finishes with a sink + pub(crate) fn new_simple( + sources: Vec>, + operators: Vec, + sink: Box, + verbose: bool, + ) -> Self { + let operators_len = operators.len(); + Self::new( + sources, + operators, + vec![ThreadedSink::new( + sink, + Rc::new(RefCell::new(1)), + operators_len, + )], + verbose, + ) + } + + /// Replace the current sources with a [`DataFrameSource`]. + fn set_df_as_sources(&mut self, df: DataFrame) { + let src = Box::new(DataFrameSource::from_df(df)) as Box; + self.set_sources(src) + } + + /// Replace the current sources. + fn set_sources(&mut self, src: Box) { + self.sources.clear(); + self.sources.push(src); + } + + fn run_pipeline_no_finalize( + &mut self, + ec: &PExecutionContext, + pipelines: &mut Vec, + ) -> PolarsResult<(u32, Box)> { + let mut out = None; + let mut operator_start = 0; + let last_i = self.sinks.len() - 1; + + // For unions we typically first want to push all pipelines + // into the union sink before we call `finalize` + // however if the sink is finished early, (for instance a `head`) + // we don't want to run the rest of the pipelines and we finalize early + let mut sink_finished = false; + + for (i, mut sink) in std::mem::take(&mut self.sinks).into_iter().enumerate() { + for src in &mut std::mem::take(&mut self.sources) { + let mut next_batches = src.get_batches(ec)?; + + let must_flush: AtomicBool = AtomicBool::new(false); + while let SourceResult::GotMoreData(chunks) = next_batches { + // Every batches iteration we check if we must continue. + ec.execution_state.should_stop()?; + + let (sink_result, next_batches2) = par_process_chunks( + chunks, + &mut sink.sinks, + ec, + &mut self.operators, + operator_start, + sink.operator_end, + src, + &must_flush, + )?; + next_batches = next_batches2; + + if let Some(SinkResult::Finished) = sink_result { + sink_finished = true; + break; + } + } + if !sink_finished && must_flush.load(Ordering::Relaxed) { + par_flush( + &mut sink.sinks, + ec, + &mut self.operators, + operator_start, + sink.operator_end, + ); + } + } + + // Before we reduce we also check if we should continue. + ec.execution_state.should_stop()?; + let allow_recursion = sink.allow_recursion(); + + // The sinks have taken all chunks thread locally, now we reduce them into a single + // result sink. + let mut reduced_sink = POOL + .install(|| { + sink.sinks.into_par_iter().reduce_with(|mut a, mut b| { + a.combine(&mut *b); + a + }) + }) + .unwrap(); + operator_start = sink.operator_end; + + let mut shared_sink_count = { + let mut shared_sink_count = sink.shared_count.borrow_mut(); + *shared_sink_count -= 1; + *shared_sink_count + }; + + // Prevent very deep recursion. Only the outer callee can pop and run. + if allow_recursion { + while shared_sink_count > 0 && !sink_finished { + let mut pipeline = pipelines.pop().unwrap(); + let (count, mut sink) = pipeline.run_pipeline_no_finalize(ec, pipelines)?; + // This branch is hit when we have a Union of joins. + // The build side must be converted into an operator and replaced in the next pipeline. + + // Check either: + // 1. There can be a union source that sinks into a single join: + // scan_parquet(*) -> join B + // 2. There can be a union of joins + // C - JOIN A, B + // concat (A, B, C) + // + // So to ensure that we don't finalize we check + // - They are not both join builds + // - If they are both join builds, check they are note the same build, otherwise + // we must call the `combine` branch. + if sink.is_join_build() + && (!reduced_sink.is_join_build() || (sink.node() != reduced_sink.node())) + { + let FinalizedSink::Operator = sink.finalize(ec)? else { + unreachable!() + }; + } else { + reduced_sink.combine(sink.as_mut()); + shared_sink_count = count; + } + } + } + + if i != last_i { + let sink_result = reduced_sink.finalize(ec)?; + match sink_result { + // turn this sink an a new source + FinalizedSink::Finished(df) => self.set_df_as_sources(df), + FinalizedSink::Source(src) => self.set_sources(src), + // should not happen + FinalizedSink::Operator => { + unreachable!() + }, + } + } else { + out = Some((shared_sink_count, reduced_sink)) + } + } + Ok(out.unwrap()) + } + + /// Run a single pipeline branch. + /// This pulls data from the sources and pushes it into the operators which run on a different + /// thread and finalize in a sink. + /// + /// The sink can be finished, but can also become a new source and then rinse and repeat. + pub fn run_pipeline( + &mut self, + ec: &PExecutionContext, + pipelines: &mut Vec, + ) -> PolarsResult> { + let (sink_shared_count, mut reduced_sink) = self.run_pipeline_no_finalize(ec, pipelines)?; + assert_eq!(sink_shared_count, 0); + Ok(reduced_sink.finalize(ec).ok()) + } +} + +/// Executes all branches and replaces operators and sinks during execution to ensure +/// we materialize. +pub fn execute_pipeline( + state: Box, + mut pipelines: Vec, +) -> PolarsResult { + let mut pipeline = pipelines.pop().unwrap(); + let ec = PExecutionContext::new(state, pipeline.verbose); + + let mut sink_out = pipeline.run_pipeline(&ec, &mut pipelines)?; + loop { + match &mut sink_out { + None => { + let mut pipeline = pipelines.pop().unwrap(); + sink_out = pipeline.run_pipeline(&ec, &mut pipelines)?; + }, + Some(FinalizedSink::Finished(df)) => return Ok(std::mem::take(df)), + Some(FinalizedSink::Source(src)) => return consume_source(&mut **src, &ec), + + // + // 1/\ + // 2/\ + // 3\ + // the left hand side of the join has finished and now is an operator + // we replace the dummy node in the right hand side pipeline with this + // operator and then we run the pipeline rinse and repeat + // until the final right hand side pipeline ran + Some(FinalizedSink::Operator) => { + // we unwrap, because the latest pipeline should not return an Operator + let mut pipeline = pipelines.pop().unwrap(); + + sink_out = pipeline.run_pipeline(&ec, &mut pipelines)?; + }, + } + } +} + +impl Debug for PipeLine { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut fmt = String::new(); + let mut start = 0usize; + fmt.push_str(self.sources[0].fmt()); + for sink in &self.sinks { + fmt.push_str(" -> "); + // take operators of a single thread + let ops = &self.operators[0]; + // slice the pipeline + let ops = &ops[start..sink.operator_end]; + for op in ops { + fmt.push_str(op.get_ref().fmt()); + fmt.push_str(" -> ") + } + start = sink.operator_end; + fmt.push_str(sink.sinks[0].fmt()) + } + write!(f, "{fmt}") + } +} + +/// Take a source and materialize it into a [`DataFrame`]. +fn consume_source(src: &mut dyn Source, context: &PExecutionContext) -> PolarsResult { + let mut frames = Vec::with_capacity(32); + + while let SourceResult::GotMoreData(batch) = src.get_batches(context)? { + frames.extend(batch.into_iter().map(|chunk| chunk.data)) + } + Ok(accumulate_dataframes_vertical_unchecked(frames)) +} + +unsafe impl Send for PipeLine {} +unsafe impl Sync for PipeLine {} diff --git a/crates/polars-pipe/src/pipeline/mod.rs b/crates/polars-pipe/src/pipeline/mod.rs index f61d5e1b329eb..45df2f2df233a 100644 --- a/crates/polars-pipe/src/pipeline/mod.rs +++ b/crates/polars-pipe/src/pipeline/mod.rs @@ -2,12 +2,16 @@ mod config; mod convert; mod dispatcher; -pub use convert::{create_pipeline, get_dummy_operator, get_operator, get_sink, swap_join_order}; -pub use dispatcher::PipeLine; +pub use convert::{ + create_pipeline, get_dummy_operator, get_operator, get_sink, swap_join_order, CallBacks, +}; +pub use dispatcher::{execute_pipeline, PipeLine}; use polars_core::prelude::*; use polars_core::POOL; +use polars_utils::cell::SyncUnsafeCell; pub use crate::executors::sinks::group_by::aggregates::can_convert_to_hash_agg; +use crate::operators::{Operator, Sink}; pub(crate) fn morsels_per_sink() -> usize { POOL.current_num_threads() @@ -33,3 +37,32 @@ pub(crate) fn determine_chunk_size(n_cols: usize, n_threads: usize) -> PolarsRes Ok(std::cmp::max(50_000 / n_cols.max(1) * thread_factor, 1000)) } } + +type PhysSink = Box; +/// A physical operator/sink per thread. +type ThreadedOperator = Vec; +type ThreadedOperatorMut<'a> = &'a mut [PhysOperator]; +type ThreadedSinkMut<'a> = &'a mut [PhysSink]; + +#[repr(transparent)] +pub(crate) struct PhysOperator { + inner: SyncUnsafeCell>, +} + +impl From> for PhysOperator { + fn from(value: Box) -> Self { + Self { + inner: SyncUnsafeCell::new(value), + } + } +} + +impl PhysOperator { + pub(crate) fn get_mut(&mut self) -> &mut dyn Operator { + &mut **self.inner.get_mut() + } + + pub(crate) fn get_ref(&self) -> &dyn Operator { + unsafe { &**self.inner.get() } + } +} diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 93a16b359978c..27e84085e6d7d 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -54,7 +54,7 @@ serde = [ ] streaming = [] parquet = ["polars-io/parquet", "polars-parquet"] -async = ["polars-io/async"] +async = ["polars-io/async", "futures"] cloud = ["async", "polars-io/cloud"] ipc = ["polars-io/ipc"] json = ["polars-io/json", "polars-json"] @@ -107,7 +107,7 @@ is_last_distinct = ["polars-core/is_last_distinct", "polars-ops/is_last_distinct is_unique = ["polars-ops/is_unique"] is_between = ["polars-ops/is_between"] cross_join = ["polars-ops/cross_join"] -asof_join = ["polars-core/asof_join", "polars-time", "polars-ops/asof_join"] +asof_join = ["polars-time", "polars-ops/asof_join"] concat_str = [] range = [] mode = ["polars-ops/mode"] diff --git a/crates/polars-plan/src/dot.rs b/crates/polars-plan/src/dot.rs index 608d0a67e6344..735813721e848 100644 --- a/crates/polars-plan/src/dot.rs +++ b/crates/polars-plan/src/dot.rs @@ -5,7 +5,6 @@ use std::path::PathBuf; use polars_core::prelude::*; use crate::prelude::*; -use crate::utils::expr_to_leaf_column_names; impl Expr { /// Get a dot language representation of the Expression. diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index 1e73613c7d041..b00347ba80072 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -4,7 +4,7 @@ use polars_ops::chunked_array::array::{ arr_default_struct_name_gen, ArrToStructNameGenerator, ToStruct, }; -use crate::dsl::function_expr::{ArrayFunction, FunctionExpr}; +use crate::dsl::function_expr::ArrayFunction; use crate::prelude::*; /// Specialized expressions for [`Series`] of [`DataType::Array`]. diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index c8a6dc3683910..1a395ac1cb6fa 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -1,4 +1,3 @@ -use super::function_expr::BinaryFunction; use super::*; /// Specialized expressions for [`Series`] of [`DataType::String`]. pub struct BinaryNameSpace(pub(crate) Expr); diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index b91deca57e9df..cf7cc2a31fbcf 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -6,7 +6,6 @@ use polars_core::prelude::*; use serde::{Deserialize, Serialize}; pub use super::expr_dyn_fn::*; -use crate::dsl::function_expr::FunctionExpr; use crate::prelude::*; #[derive(PartialEq, Clone)] diff --git a/crates/polars-plan/src/dsl/from.rs b/crates/polars-plan/src/dsl/from.rs index e815fdb7ffe49..eeaa631521cbc 100644 --- a/crates/polars-plan/src/dsl/from.rs +++ b/crates/polars-plan/src/dsl/from.rs @@ -6,8 +6,6 @@ impl From for Expr { } } -pub trait RefString {} - impl From<&str> for Expr { fn from(s: &str) -> Self { col(s) diff --git a/crates/polars-plan/src/dsl/function_expr/fill_null.rs b/crates/polars-plan/src/dsl/function_expr/fill_null.rs index f7ed47f77be69..686d0a36cd303 100644 --- a/crates/polars-plan/src/dsl/function_expr/fill_null.rs +++ b/crates/polars-plan/src/dsl/function_expr/fill_null.rs @@ -23,11 +23,7 @@ pub(super) fn fill_null(s: &[Series], super_type: &DataType) -> PolarsResult PolarsResult { - // broadcast to the proper length for zip_with - if fill_value.len() == 1 && series.len() != 1 { - fill_value = fill_value.new_from_index(0, series.len()); - } + fn default(series: Series, fill_value: Series) -> PolarsResult { let mask = series.is_not_null(); series.zip_with_same_type(&mask, &fill_value) } diff --git a/crates/polars-plan/src/dsl/function_expr/fused.rs b/crates/polars-plan/src/dsl/function_expr/fused.rs index f9cb93c1b560b..a95ac809ebc77 100644 --- a/crates/polars-plan/src/dsl/function_expr/fused.rs +++ b/crates/polars-plan/src/dsl/function_expr/fused.rs @@ -1,5 +1,3 @@ -use std::fmt::{Display, Formatter}; - #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index cd5942ed40c63..397959e9980f7 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -73,10 +73,6 @@ pub(crate) use correlation::CorrelationMethod; pub(crate) use fused::FusedOperator; pub(super) use list::ListFunction; use polars_core::prelude::*; -#[cfg(feature = "cutqcut")] -use polars_ops::prelude::{cut, qcut}; -#[cfg(feature = "rle")] -use polars_ops::prelude::{rle, rle_id}; #[cfg(feature = "random")] pub(crate) use random::RandomMethod; use schema::FieldsMapper; diff --git a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs index bacda8bc45ae3..1dd96e3f6af4b 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/date_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/date_range.rs @@ -1,5 +1,4 @@ use polars_core::prelude::*; -use polars_core::series::Series; use polars_core::utils::arrow::temporal_conversions::MILLISECONDS_IN_DAY; use polars_time::{datetime_range_impl, ClosedWindow, Duration}; diff --git a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs index 10cca7f3ccf6d..3c61e60259c44 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/datetime_range.rs @@ -1,7 +1,6 @@ #[cfg(feature = "timezones")] use polars_core::chunked_array::temporal::parse_time_zone; use polars_core::prelude::*; -use polars_core::series::Series; use polars_time::{datetime_range_impl, ClosedWindow, Duration}; use super::utils::{ diff --git a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs index 3b4206e3ae0a4..5344ec0b5ee88 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/int_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/int_range.rs @@ -1,5 +1,4 @@ use polars_core::prelude::*; -use polars_core::series::Series; use polars_core::with_match_physical_integer_polars_type; use polars_ops::series::new_int_range; diff --git a/crates/polars-plan/src/dsl/function_expr/range/mod.rs b/crates/polars-plan/src/dsl/function_expr/range/mod.rs index ab0508eff3de4..dfee18e7e7ccf 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/mod.rs @@ -10,7 +10,6 @@ mod utils; use std::fmt::{Display, Formatter}; use polars_core::prelude::*; -use polars_core::series::Series; #[cfg(feature = "temporal")] use polars_time::{ClosedWindow, Duration}; #[cfg(feature = "serde")] diff --git a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs index 2d799a471269a..991368356cc51 100644 --- a/crates/polars-plan/src/dsl/function_expr/range/time_range.rs +++ b/crates/polars-plan/src/dsl/function_expr/range/time_range.rs @@ -1,5 +1,4 @@ use polars_core::prelude::*; -use polars_core::series::Series; use polars_time::{time_range_impl, ClosedWindow, Duration}; use super::utils::{ diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index 272ef6f6ba131..67772ef31adf6 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -1,3 +1,5 @@ +use polars_time::chunkedarray::*; + use super::*; #[derive(Clone, PartialEq, Debug)] diff --git a/crates/polars-plan/src/dsl/function_expr/search_sorted.rs b/crates/polars-plan/src/dsl/function_expr/search_sorted.rs index 9d4566dfce65e..87933fc7bd6cd 100644 --- a/crates/polars-plan/src/dsl/function_expr/search_sorted.rs +++ b/crates/polars-plan/src/dsl/function_expr/search_sorted.rs @@ -1,5 +1,3 @@ -use polars_ops::prelude::search_sorted; - use super::*; pub(super) fn search_sorted_impl(s: &mut [Series], side: SearchSortedSide) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/temporal.rs b/crates/polars-plan/src/dsl/function_expr/temporal.rs index c5660aac8f8c3..279f03883db70 100644 --- a/crates/polars-plan/src/dsl/function_expr/temporal.rs +++ b/crates/polars-plan/src/dsl/function_expr/temporal.rs @@ -154,9 +154,9 @@ pub(super) fn datetime( NaiveDate::from_ymd_opt(y, m, d) .and_then(|nd| nd.and_hms_micro_opt(h, mnt, s, us)) .map(|ndt| match time_unit { - TimeUnit::Milliseconds => ndt.timestamp_millis(), - TimeUnit::Microseconds => ndt.timestamp_micros(), - TimeUnit::Nanoseconds => ndt.timestamp_nanos_opt().unwrap(), + TimeUnit::Milliseconds => ndt.and_utc().timestamp_millis(), + TimeUnit::Microseconds => ndt.and_utc().timestamp_micros(), + TimeUnit::Nanoseconds => ndt.and_utc().timestamp_nanos_opt().unwrap(), }) } else { None diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index f97c3f4935ea7..603ec2553590f 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -5,7 +5,6 @@ use polars_core::prelude::*; #[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; -use crate::dsl::function_expr::FunctionExpr; use crate::prelude::function_expr::ListFunction; use crate::prelude::*; diff --git a/crates/polars-plan/src/dsl/meta.rs b/crates/polars-plan/src/dsl/meta.rs index 844d951204d86..28a554007a500 100644 --- a/crates/polars-plan/src/dsl/meta.rs +++ b/crates/polars-plan/src/dsl/meta.rs @@ -2,7 +2,6 @@ use std::fmt::Display; use std::ops::BitAnd; use super::*; -use crate::dsl::selector::Selector; use crate::logical_plan::projection::is_regex_projection; use crate::logical_plan::tree_format::TreeFmtVisitor; use crate::logical_plan::visitor::{AexprNode, TreeWalker}; diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index ed8ec878870e7..007c6939f7120 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -8,6 +8,8 @@ use std::any::Any; #[cfg(feature = "dtype-categorical")] pub use cat::*; +#[cfg(feature = "rolling_window")] +pub(crate) use polars_time::prelude::*; mod arithmetic; mod arity; #[cfg(feature = "dtype-array")] @@ -59,8 +61,6 @@ use polars_core::prelude::*; use polars_core::series::ops::NullBehavior; use polars_core::series::IsSorted; use polars_core::utils::try_get_supertype; -#[cfg(feature = "rolling_window")] -use polars_time::prelude::SeriesOpsTime; pub(crate) use selector::Selector; #[cfg(feature = "dtype-struct")] pub use struct_::*; @@ -69,9 +69,6 @@ pub use udf::UserDefinedFunction; use crate::constants::MAP_LIST_NAME; pub use crate::logical_plan::lit; use crate::prelude::*; -use crate::utils::has_expr; -#[cfg(feature = "is_in")] -use crate::utils::has_leaf_literal; impl Expr { /// Modify the Options passed to the `Function` node. @@ -740,7 +737,7 @@ impl Expr { }; self.function_with_options( - move |s: Series| Ok(Some(s.product())), + move |s: Series| Some(s.product()).transpose(), GetOutput::map_dtype(|dt| { use DataType::*; match dt { diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index b15e1fbe9b5d1..e1fa05d419f06 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -1,7 +1,6 @@ use std::io::Cursor; use std::sync::Arc; -use arrow::legacy::error::PolarsResult; use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::DataFrame; diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 42a7cb2471fd1..88c43c4e5ff76 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -1,4 +1,3 @@ -use super::function_expr::StringFunction; use super::*; /// Specialized expressions for [`Series`] of [`DataType::String`]. pub struct StringNameSpace(pub(crate) Expr); diff --git a/crates/polars-plan/src/dsl/struct_.rs b/crates/polars-plan/src/dsl/struct_.rs index eb6c066dca3fd..db02ff0230453 100644 --- a/crates/polars-plan/src/dsl/struct_.rs +++ b/crates/polars-plan/src/dsl/struct_.rs @@ -1,5 +1,4 @@ use super::*; -use crate::dsl::function_expr::StructFunction; /// Specialized expressions for Struct dtypes. pub struct StructNameSpace(pub(crate) Expr); diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs index 1415ffd66acab..ff3fe82061096 100644 --- a/crates/polars-plan/src/frame/opt_state.rs +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -1,19 +1,34 @@ #[derive(Copy, Clone, Debug)] /// State of the allowed optimizations pub struct OptState { + /// Only read columns that are used later in the query. pub projection_pushdown: bool, + /// Apply predicates/filters as early as possible. pub predicate_pushdown: bool, + /// Run many type coercion optimization rules until fixed point. pub type_coercion: bool, + /// Run many expression optimization rules until fixed point. pub simplify_expr: bool, + /// Cache file reads. pub file_caching: bool, + /// Pushdown slices/limits. pub slice_pushdown: bool, #[cfg(feature = "cse")] + /// Run common-subplan-elimination. This elides duplicate plans and caches their + /// outputs. pub comm_subplan_elim: bool, #[cfg(feature = "cse")] + /// Run common-subexpression-elimination. This elides duplicate expressions and caches their + /// outputs. pub comm_subexpr_elim: bool, + /// Run nodes that are capably of doing so on the streaming engine. pub streaming: bool, + /// Run every node eagerly. This turns off multi-node optimizations. pub eager: bool, + /// Replace simple projections with a faster inlined projection that skips the expression engine. pub fast_projection: bool, + /// Try to estimate the number of rows so that joins can determine which side to keep in memory. + pub row_estimate: bool, } impl Default for OptState { @@ -33,6 +48,7 @@ impl Default for OptState { streaming: false, fast_projection: true, eager: false, + row_estimate: true, } } } diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index 706a589d42984..300ed2b9472bb 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -2,16 +2,12 @@ mod hash; mod schema; use std::hash::{Hash, Hasher}; -use std::sync::Arc; -use arrow::legacy::prelude::QuantileInterpolOptions; -use polars_core::frame::group_by::GroupByMethod; use polars_core::prelude::*; use polars_core::utils::{get_time_units, try_get_supertype}; use polars_utils::arena::{Arena, Node}; use strum_macros::IntoStaticStr; -use crate::dsl::function_expr::FunctionExpr; #[cfg(feature = "cse")] use crate::logical_plan::visitor::AexprNode; use crate::logical_plan::Context; diff --git a/crates/polars-plan/src/logical_plan/alp.rs b/crates/polars-plan/src/logical_plan/alp.rs index cdb45586e572a..547be5f05592c 100644 --- a/crates/polars-plan/src/logical_plan/alp.rs +++ b/crates/polars-plan/src/logical_plan/alp.rs @@ -1,18 +1,12 @@ use std::borrow::Cow; use std::path::PathBuf; -use std::sync::Arc; use polars_core::prelude::*; -use polars_utils::arena::{Arena, Node}; use polars_utils::idx_vec::UnitVec; use polars_utils::unitvec; use super::projection_expr::*; -use crate::logical_plan::functions::FunctionNode; -use crate::logical_plan::schema::FileInfo; -use crate::logical_plan::FileScan; use crate::prelude::*; -use crate::utils::PushNode; /// [`ALogicalPlan`] is a representation of [`LogicalPlan`] with [`Node`]s which are allocated in an [`Arena`] #[derive(Clone, Debug)] diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index 5f44f2deb6dd5..b3beecddfeadb 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -1,12 +1,9 @@ #[cfg(feature = "csv")] use std::io::{Read, Seek}; -use polars_core::frame::explode::MeltArgs; use polars_core::prelude::*; #[cfg(feature = "parquet")] use polars_io::cloud::CloudOptions; -#[cfg(feature = "ipc")] -use polars_io::ipc::IpcReader; #[cfg(all(feature = "parquet", feature = "async"))] use polars_io::parquet::ParquetAsyncReader; #[cfg(feature = "parquet")] @@ -31,9 +28,7 @@ use polars_io::{ use super::builder_functions::*; use crate::dsl::functions::horizontal::all_horizontal; -use crate::logical_plan::functions::FunctionNode; use crate::logical_plan::projection::{is_regex_projection, rewrite_projections}; -use crate::logical_plan::schema::{det_join_schema, FileInfo}; #[cfg(feature = "python")] use crate::prelude::python_udf::PythonFunction; use crate::prelude::*; @@ -242,38 +237,57 @@ impl LogicalPlanBuilder { cache: bool, row_index: Option, rechunk: bool, + #[cfg(feature = "cloud")] cloud_options: Option, ) -> PolarsResult { - use polars_io::SerReader as _; + use polars_io::is_cloud_url; let path = path.into(); - let file = polars_utils::open_file(&path)?; - let mut reader = IpcReader::new(file); - let reader_schema = reader.schema()?; - let mut schema: Schema = (&reader_schema).into(); - if let Some(rc) = &row_index { - let _ = schema.insert_at_index(0, rc.name.as_str().into(), IDX_DTYPE); - } - - let num_rows = reader._num_rows()?; - let file_info = FileInfo::new(Arc::new(schema), Some(reader_schema), (None, num_rows)); + let metadata = if is_cloud_url(&path) { + #[cfg(not(feature = "cloud"))] + panic!( + "One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled." + ); - let file_options = FileScanOptions { - with_columns: None, - cache, - n_rows, - rechunk, - row_index, - file_counter: Default::default(), - // TODO! add - hive_partitioning: false, + #[cfg(feature = "cloud")] + { + let uri = path.to_string_lossy(); + get_runtime().block_on(async { + polars_io::ipc::IpcReaderAsync::from_uri(&uri, cloud_options.as_ref()) + .await? + .metadata() + .await + })? + } + } else { + arrow::io::ipc::read::read_file_metadata(&mut std::io::BufReader::new( + polars_utils::open_file(&path)?, + ))? }; + Ok(LogicalPlan::Scan { paths: Arc::new([path]), - file_info, - file_options, + file_info: FileInfo::new( + prepare_schema(metadata.schema.as_ref().into(), row_index.as_ref()), + Some(Arc::clone(&metadata.schema)), + (None, 0), + ), + file_options: FileScanOptions { + with_columns: None, + cache, + n_rows, + rechunk, + row_index, + file_counter: Default::default(), + hive_partitioning: false, + }, predicate: None, - scan_type: FileScan::Ipc { options }, + scan_type: FileScan::Ipc { + options, + #[cfg(feature = "cloud")] + cloud_options, + metadata: Some(metadata), + }, } .into()) } @@ -434,7 +448,7 @@ impl LogicalPlanBuilder { if columns.is_empty() { self.map( - |_| Ok(DataFrame::new_no_checks(vec![])), + |_| Ok(DataFrame::empty()), AllowedOptimizations::default(), Some(Arc::new(|_: &Schema| Ok(Arc::new(Schema::default())))), "EMPTY PROJECTION", @@ -459,7 +473,7 @@ impl LogicalPlanBuilder { if exprs.is_empty() { self.map( - |_| Ok(DataFrame::new_no_checks(vec![])), + |_| Ok(DataFrame::empty()), AllowedOptimizations::default(), Some(Arc::new(|_: &Schema| Ok(Arc::new(Schema::default())))), "EMPTY PROJECTION", diff --git a/crates/polars-plan/src/logical_plan/file_scan.rs b/crates/polars-plan/src/logical_plan/file_scan.rs index 2364711eef3b6..8f7319574c0c0 100644 --- a/crates/polars-plan/src/logical_plan/file_scan.rs +++ b/crates/polars-plan/src/logical_plan/file_scan.rs @@ -11,12 +11,18 @@ pub enum FileScan { #[cfg(feature = "parquet")] Parquet { options: ParquetOptions, - cloud_options: Option, + cloud_options: Option, #[cfg_attr(feature = "serde", serde(skip))] metadata: Option>, }, #[cfg(feature = "ipc")] - Ipc { options: IpcScanOptions }, + Ipc { + options: IpcScanOptions, + #[cfg(feature = "cloud")] + cloud_options: Option, + #[cfg_attr(feature = "serde", serde(skip))] + metadata: Option, + }, #[cfg_attr(feature = "serde", serde(skip))] Anonymous { options: Arc, @@ -43,7 +49,29 @@ impl PartialEq for FileScan { }, ) => opt_l == opt_r && c_l == c_r, #[cfg(feature = "ipc")] - (FileScan::Ipc { options: l }, FileScan::Ipc { options: r }) => l == r, + ( + FileScan::Ipc { + options: l, + #[cfg(feature = "cloud")] + cloud_options: c_l, + .. + }, + FileScan::Ipc { + options: r, + #[cfg(feature = "cloud")] + cloud_options: c_r, + .. + }, + ) => { + #[cfg(not(feature = "cloud"))] + { + l == r + } + #[cfg(feature = "cloud")] + { + l == r && c_l == c_r + } + }, _ => false, } } diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 456aee8955fdf..91c384aa86bbd 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -202,7 +202,11 @@ impl LogicalPlan { input._format(f, sub_indent) }, Distinct { input, options } => { - write!(f, "{:indent$}UNIQUE BY {:?}", "", options.subset)?; + write!( + f, + "{:indent$}UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}", + "", options.maintain_order, options.keep_strategy, options.subset + )?; input._format(f, sub_indent) }, Slice { input, offset, len } => { diff --git a/crates/polars-plan/src/logical_plan/functions/count.rs b/crates/polars-plan/src/logical_plan/functions/count.rs new file mode 100644 index 0000000000000..1da3484928fc2 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/functions/count.rs @@ -0,0 +1,162 @@ +#[cfg(feature = "ipc")] +use arrow::io::ipc::read::get_row_count as count_rows_ipc_sync; +#[cfg(feature = "ipc")] +use polars_core::error::to_compute_err; +#[cfg(feature = "parquet")] +use polars_io::cloud::CloudOptions; +#[cfg(feature = "csv")] +use polars_io::csv::count_rows as count_rows_csv; +#[cfg(all(feature = "parquet", feature = "cloud"))] +use polars_io::parquet::ParquetAsyncReader; +#[cfg(feature = "parquet")] +use polars_io::parquet::ParquetReader; +#[cfg(all(feature = "parquet", feature = "async"))] +use polars_io::pl_async::{get_runtime, with_concurrency_budget}; +#[cfg(feature = "parquet")] +use polars_io::{is_cloud_url, SerReader}; + +use super::*; + +#[allow(unused_variables)] +pub fn count_rows(paths: &Arc<[PathBuf]>, scan_type: &FileScan) -> PolarsResult { + match scan_type { + #[cfg(feature = "csv")] + FileScan::Csv { options } => { + let n_rows: PolarsResult = paths + .iter() + .map(|path| { + count_rows_csv( + path, + options.separator, + options.quote_char, + options.comment_prefix.as_ref(), + options.eol_char, + options.has_header, + ) + }) + .sum(); + Ok(DataFrame::new(vec![Series::new("len", [n_rows? as IdxSize])]).unwrap()) + }, + #[cfg(feature = "parquet")] + FileScan::Parquet { cloud_options, .. } => { + let n_rows = count_rows_parquet(paths, cloud_options.as_ref())?; + Ok(DataFrame::new(vec![Series::new("len", [n_rows as IdxSize])]).unwrap()) + }, + #[cfg(feature = "ipc")] + FileScan::Ipc { + options, + #[cfg(feature = "cloud")] + cloud_options, + metadata, + } => { + let count: IdxSize = count_rows_ipc( + paths, + #[cfg(feature = "cloud")] + cloud_options.as_ref(), + metadata.as_ref(), + )? + .try_into() + .map_err(to_compute_err)?; + Ok(DataFrame::new(vec![Series::new("len", [count])]).unwrap()) + }, + FileScan::Anonymous { .. } => { + unreachable!(); + }, + } +} +#[cfg(feature = "parquet")] +pub(super) fn count_rows_parquet( + paths: &Arc<[PathBuf]>, + cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + if paths.is_empty() { + return Ok(0); + }; + let is_cloud = is_cloud_url(paths.first().unwrap().as_path()); + + if is_cloud { + #[cfg(not(feature = "cloud"))] + panic!("One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled."); + + #[cfg(feature = "cloud")] + { + get_runtime().block_on(count_rows_cloud_parquet(paths, cloud_options)) + } + } else { + paths + .iter() + .map(|path| { + let file = polars_utils::open_file(path)?; + let mut reader = ParquetReader::new(file); + reader.num_rows() + }) + .sum::>() + } +} + +#[cfg(all(feature = "parquet", feature = "async"))] +async fn count_rows_cloud_parquet( + paths: &Arc<[PathBuf]>, + cloud_options: Option<&CloudOptions>, +) -> PolarsResult { + let collection = paths.iter().map(|path| { + with_concurrency_budget(1, || async { + let mut reader = + ParquetAsyncReader::from_uri(&path.to_string_lossy(), cloud_options, None, None) + .await?; + reader.num_rows().await + }) + }); + futures::future::try_join_all(collection) + .await + .map(|rows| rows.iter().sum()) +} + +#[cfg(feature = "ipc")] +pub(super) fn count_rows_ipc( + paths: &Arc<[PathBuf]>, + #[cfg(feature = "cloud")] cloud_options: Option<&CloudOptions>, + metadata: Option<&arrow::io::ipc::read::FileMetadata>, +) -> PolarsResult { + if paths.is_empty() { + return Ok(0); + }; + let is_cloud = is_cloud_url(paths.first().unwrap().as_path()); + + if is_cloud { + #[cfg(not(feature = "cloud"))] + panic!("One or more of the cloud storage features ('aws', 'gcp', ...) must be enabled."); + + #[cfg(feature = "cloud")] + { + get_runtime().block_on(count_rows_cloud_ipc(paths, cloud_options, metadata)) + } + } else { + paths + .iter() + .map(|path| { + let mut reader = polars_utils::open_file(path)?; + count_rows_ipc_sync(&mut reader) + }) + .sum() + } +} + +#[cfg(all(feature = "ipc", feature = "async"))] +async fn count_rows_cloud_ipc( + paths: &Arc<[PathBuf]>, + cloud_options: Option<&CloudOptions>, + metadata: Option<&arrow::io::ipc::read::FileMetadata>, +) -> PolarsResult { + use polars_io::ipc::IpcReaderAsync; + + let collection = paths.iter().map(|path| { + with_concurrency_budget(1, || async { + let reader = IpcReaderAsync::from_uri(&path.to_string_lossy(), cloud_options).await?; + reader.count_rows(metadata).await + }) + }); + futures::future::try_join_all(collection) + .await + .map(|rows| rows.iter().sum()) +} diff --git a/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs b/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs index 1b486e589f0f2..a20a85d688120 100644 --- a/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs +++ b/crates/polars-plan/src/logical_plan/functions/merge_sorted.rs @@ -29,8 +29,8 @@ pub(super) fn merge_sorted(df: &DataFrame, column: &str) -> PolarsResult, + scan_type: FileScan, + alias: Option>, + }, #[cfg_attr(feature = "serde", serde(skip))] Pipeline { function: Arc, @@ -111,6 +115,7 @@ impl PartialEq for FunctionNode { ) => l == r && dl == dr, (DropNulls { subset: l }, DropNulls { subset: r }) => l == r, (Rechunk, Rechunk) => true, + (Count { paths: paths_l, .. }, Count { paths: paths_r, .. }) => paths_l == paths_r, ( Rename { existing: existing_l, @@ -141,6 +146,7 @@ impl FunctionNode { MergeSorted { .. } => false, DropNulls { .. } | FastProjection { .. } + | Count { .. } | Unnest { .. } | Rename { .. } | Explode { .. } => true, @@ -193,6 +199,13 @@ impl FunctionNode { Ok(Cow::Owned(Arc::new(schema))) }, DropNulls { .. } => Ok(Cow::Borrowed(input_schema)), + Count { alias, .. } => { + let mut schema: Schema = Schema::with_capacity(1); + let name = + SmartString::from(alias.as_ref().map(|alias| alias.as_ref()).unwrap_or("len")); + schema.insert_at_index(0, name, IDX_DTYPE)?; + Ok(Cow::Owned(Arc::new(schema))) + }, Rechunk => Ok(Cow::Borrowed(input_schema)), Unnest { columns: _columns } => { #[cfg(feature = "dtype-struct")] @@ -254,7 +267,7 @@ impl FunctionNode { | Melt { .. } => true, #[cfg(feature = "merge_sorted")] MergeSorted { .. } => true, - RowIndex { .. } => false, + RowIndex { .. } | Count { .. } => false, Pipeline { .. } => unimplemented!(), } } @@ -268,6 +281,7 @@ impl FunctionNode { FastProjection { .. } | DropNulls { .. } | Rechunk + | Count { .. } | Unnest { .. } | Rename { .. } | Explode { .. } @@ -312,6 +326,9 @@ impl FunctionNode { } }, DropNulls { subset } => df.drop_nulls(Some(subset.as_ref())), + Count { + paths, scan_type, .. + } => count::count_rows(paths, scan_type), Rechunk => { df.as_single_chunk_par(); Ok(df) @@ -376,6 +393,7 @@ impl Display for FunctionNode { fmt_column_delimited(f, subset, "[", "]") }, Rechunk => write!(f, "RECHUNK"), + Count { .. } => write!(f, "FAST COUNT(*)"), Unnest { columns } => { write!(f, "UNNEST by:")?; let columns = columns.as_ref(); diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index 4965cd2c7d99d..e9e22281517c7 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -260,13 +260,13 @@ impl Literal for NaiveDateTime { fn lit(self) -> Expr { if in_nanoseconds_window(&self) { Expr::Literal(LiteralValue::DateTime( - self.timestamp_nanos_opt().unwrap(), + self.and_utc().timestamp_nanos_opt().unwrap(), TimeUnit::Nanoseconds, None, )) } else { Expr::Literal(LiteralValue::DateTime( - self.timestamp_micros(), + self.and_utc().timestamp_micros(), TimeUnit::Microseconds, None, )) diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index f378770456adb..8c023cf3935d5 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -3,8 +3,6 @@ use std::path::PathBuf; use std::sync::{Arc, Mutex}; use polars_core::prelude::*; -#[cfg(any(feature = "cloud", feature = "parquet"))] -use polars_io::cloud::CloudOptions; use crate::logical_plan::LogicalPlan::DataFrameScan; use crate::prelude::*; diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs index 23881fdd4b3c3..a7be49238fe9f 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs @@ -1,5 +1,4 @@ use std::collections::BTreeMap; -use std::sync::Arc; use super::*; diff --git a/crates/polars-plan/src/logical_plan/optimizer/count_star.rs b/crates/polars-plan/src/logical_plan/optimizer/count_star.rs new file mode 100644 index 0000000000000..357e1bd608393 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/count_star.rs @@ -0,0 +1,133 @@ +use std::path::PathBuf; + +use super::*; + +pub(super) struct CountStar; + +impl CountStar { + pub(super) fn new() -> Self { + Self + } +} + +impl OptimizationRule for CountStar { + // Replace select count(*) from datasource with specialized map function. + fn optimize_plan( + &mut self, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + node: Node, + ) -> Option { + visit_logical_plan_for_scan_paths(node, lp_arena, expr_arena, false).map( + |count_star_expr| { + // MapFunction needs a leaf node, hence we create a dummy placeholder node + let placeholder = ALogicalPlan::DataFrameScan { + df: Arc::new(Default::default()), + schema: Arc::new(Default::default()), + output_schema: None, + projection: None, + selection: None, + }; + let placeholder_node = lp_arena.add(placeholder); + + let alp = ALogicalPlan::MapFunction { + input: placeholder_node, + function: FunctionNode::Count { + paths: count_star_expr.paths, + scan_type: count_star_expr.scan_type, + alias: count_star_expr.alias, + }, + }; + + lp_arena.replace(count_star_expr.node, alp.clone()); + alp + }, + ) + } +} + +struct CountStarExpr { + // Top node of the projection to replace + node: Node, + // Paths to the input files + paths: Arc<[PathBuf]>, + // File Type + scan_type: FileScan, + // Column Alias + alias: Option>, +} + +// Visit the logical plan and return CountStarExpr with the expr information gathered +// Return None if query is not a simple COUNT(*) FROM SOURCE +fn visit_logical_plan_for_scan_paths( + node: Node, + lp_arena: &Arena, + expr_arena: &Arena, + inside_union: bool, // Inside union's we do not check for COUNT(*) expression +) -> Option { + match lp_arena.get(node) { + ALogicalPlan::Union { inputs, .. } => { + let mut scan_type: Option = None; + let mut paths = Vec::with_capacity(inputs.len()); + for input in inputs { + match visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, true) { + Some(expr) => { + paths.extend(expr.paths.iter().cloned()); + match &scan_type { + None => scan_type = Some(expr.scan_type), + Some(scan_type) => { + // All scans must be of the same type (e.g. csv / parquet) + if std::mem::discriminant(scan_type) + != std::mem::discriminant(&expr.scan_type) + { + return None; + } + }, + }; + }, + None => return None, + } + } + Some(CountStarExpr { + paths: paths.into(), + scan_type: scan_type.unwrap(), + node, + alias: None, + }) + }, + ALogicalPlan::Scan { + scan_type, paths, .. + } if !matches!(scan_type, FileScan::Anonymous { .. }) => Some(CountStarExpr { + paths: paths.clone(), + scan_type: scan_type.clone(), + node, + alias: None, + }), + ALogicalPlan::Projection { input, expr, .. } => { + if expr.len() == 1 { + let (valid, alias) = is_valid_count_expr(expr[0], expr_arena); + if valid || inside_union { + return visit_logical_plan_for_scan_paths(*input, lp_arena, expr_arena, false) + .map(|mut expr| { + expr.alias = alias; + expr.node = node; + expr + }); + } + } + None + }, + _ => None, + } +} + +fn is_valid_count_expr(node: Node, expr_arena: &Arena) -> (bool, Option>) { + match expr_arena.get(node) { + AExpr::Alias(node, alias) => { + let (valid, _) = is_valid_count_expr(*node, expr_arena); + (valid, Some(alias.clone())) + }, + AExpr::Len => (true, None), + _ => (false, None), + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse.rs b/crates/polars-plan/src/logical_plan/optimizer/cse.rs index da3ac4ada8abd..0af97b7732425 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse.rs @@ -265,7 +265,7 @@ pub(crate) fn elim_cmn_subplans( } let trails = trails.into_values().collect::>(); - // search from the leafs upwards and find the longest shared subplans + // search from the leaf nodes upwards and find the longest shared subplans let mut trail_ends = vec![]; // if i matches j // we don't need to search with j as they are equal diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs index e7f4f7b6d1047..d0cd86b2a966e 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs @@ -3,8 +3,7 @@ use polars_utils::vec::CapacityByFactor; use super::*; use crate::constants::CSE_REPLACED; use crate::logical_plan::projection_expr::ProjectionExprs; -use crate::logical_plan::visitor::{RewriteRecursion, VisitRecursion}; -use crate::prelude::visitor::{ALogicalPlanNode, AexprNode, RewritingVisitor, TreeWalker, Visitor}; +use crate::prelude::visitor::AexprNode; // We use hashes to get an Identifier // but this is very hard to debug, so we also have a version that diff --git a/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs b/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs index 6c3705ca93ee5..7fa81a4085e93 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/delay_rechunk.rs @@ -1,7 +1,5 @@ use std::collections::BTreeSet; -use polars_utils::arena::{Arena, Node}; - use super::*; #[derive(Default)] diff --git a/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs b/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs index 8683d2716a03c..937750f94024a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/drop_nulls.rs @@ -1,10 +1,5 @@ -use std::sync::Arc; - use super::*; -use crate::dsl::function_expr::FunctionExpr; -use crate::logical_plan::functions::FunctionNode; use crate::logical_plan::iterator::*; -use crate::utils::aexpr_to_leaf_names; /// If we realize that a predicate drops nulls on a subset /// we replace it with an explicit df.drop_nulls call, as this diff --git a/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs b/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs index 14ba5f29ebbc6..6f8e7f27a453a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/fast_projection.rs @@ -4,8 +4,6 @@ use polars_core::prelude::*; use smartstring::SmartString; use super::*; -use crate::logical_plan::alp::ALogicalPlan; -use crate::logical_plan::functions::FunctionNode; /// Projection in the physical plan is done by selecting an expression per thread. /// In case of many projections and columns this can be expensive when the expressions are simple diff --git a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs index 23791d3dd6b0d..d3377a1452294 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs @@ -1,10 +1,7 @@ use std::path::PathBuf; -use std::sync::Arc; -use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; -use crate::logical_plan::ALogicalPlanBuilder; use crate::prelude::*; #[derive(Hash, Eq, PartialEq, Clone, Debug)] diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs index 44448a61c334f..226e1636c6da5 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs @@ -1,4 +1,3 @@ -use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; use crate::prelude::*; @@ -10,6 +9,7 @@ mod delay_rechunk; mod drop_nulls; mod collect_members; +mod count_star; #[cfg(feature = "cse")] mod cse_expr; mod fast_projection; @@ -36,8 +36,6 @@ mod type_coercion; use delay_rechunk::DelayRechunk; use drop_nulls::ReplaceDropNulls; use fast_projection::FastProjectionAndCollapse; -#[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] -use file_caching::{find_column_union_and_fingerprints, FileCacher}; use polars_io::predicates::PhysicalIoExpr; pub use predicate_pushdown::PredicatePushDown; pub use projection_pushdown::ProjectionPushDown; @@ -48,6 +46,7 @@ pub use type_coercion::TypeCoercionRule; use self::flatten_union::FlattenUnionRule; pub use crate::frame::{AllowedOptimizations, OptState}; +use crate::logical_plan::optimizer::count_star::CountStar; #[cfg(feature = "cse")] use crate::logical_plan::optimizer::cse_expr::CommonSubExprOptimizer; use crate::logical_plan::optimizer::predicate_pushdown::HiveEval; @@ -141,6 +140,11 @@ pub fn optimize( if members.has_joins_or_unions && members.has_cache { cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_plan_changed); } + + if projection_pushdown_opt.is_count_star { + let mut count_star_opt = CountStar::new(); + count_star_opt.optimize_plan(lp_arena, expr_arena, lp_top); + } } if predicate_pushdown { diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs index 5db230fffd676..0e32d298d70cf 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs @@ -1,10 +1,17 @@ use super::*; -// information concerning individual sides of a join +// Information concerning individual sides of a join. #[derive(PartialEq, Eq)] struct LeftRight(T, T); -fn should_block_join_specific(ae: &AExpr, how: &JoinType) -> LeftRight { +fn should_block_join_specific( + ae: &AExpr, + how: &JoinType, + on_names: &PlHashSet>, + expr_arena: &Arena, + schema_left: &Schema, + schema_right: &Schema, +) -> LeftRight { use AExpr::*; match ae { // joins can produce null values @@ -36,12 +43,23 @@ fn should_block_join_specific(ae: &AExpr, how: &JoinType) -> LeftRight { // any operation that checks for equality or ordering can be wrong because // the join can produce null values // TODO! check if we can be less conservative here - BinaryExpr { op, .. } => { - if matches!(op, Operator::NotEq) { - LeftRight(false, false) - } else { - join_produces_null(how) - } + BinaryExpr { op, left, right } => match op { + Operator::NotEq => LeftRight(false, false), + Operator::Eq => { + let LeftRight(bleft, bright) = join_produces_null(how); + + let l_name = aexpr_output_name(*left, expr_arena).unwrap(); + let r_name = aexpr_output_name(*right, expr_arena).unwrap(); + + let is_in_on = on_names.contains(&l_name) || on_names.contains(&r_name); + + let block_left = + is_in_on && (schema_left.contains(&l_name) || schema_left.contains(&r_name)); + let block_right = + is_in_on && (schema_right.contains(&l_name) || schema_right.contains(&r_name)); + LeftRight(block_left | bleft, block_right | bright) + }, + _ => join_produces_null(how), }, _ => LeftRight(false, false), } @@ -98,6 +116,16 @@ pub(super) fn process_join( let schema_left = lp_arena.get(input_left).schema(lp_arena); let schema_right = lp_arena.get(input_right).schema(lp_arena); + let on_names = left_on + .iter() + .flat_map(|n| aexpr_to_leaf_names_iter(*n, expr_arena)) + .chain( + right_on + .iter() + .flat_map(|n| aexpr_to_leaf_names_iter(*n, expr_arena)), + ) + .collect::>(); + let mut pushdown_left = init_hashmap(Some(acc_predicates.len())); let mut pushdown_right = init_hashmap(Some(acc_predicates.len())); let mut local_predicates = Vec::with_capacity(acc_predicates.len()); @@ -105,10 +133,26 @@ pub(super) fn process_join( for (_, predicate) in acc_predicates { // check if predicate can pass the joins node let block_pushdown_left = has_aexpr(predicate, expr_arena, |ae| { - should_block_join_specific(ae, &options.args.how).0 + should_block_join_specific( + ae, + &options.args.how, + &on_names, + expr_arena, + &schema_left, + &schema_right, + ) + .0 }); let block_pushdown_right = has_aexpr(predicate, expr_arena, |ae| { - should_block_join_specific(ae, &options.args.how).1 + should_block_join_specific( + ae, + &options.args.how, + &on_names, + expr_arena, + &schema_left, + &schema_right, + ) + .1 }); // these indicate to which tables we are going to push down the predicate diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index c9f08519ffb43..b3c83ca70ae64 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -11,7 +11,6 @@ use utils::*; use super::*; use crate::dsl::function_expr::FunctionExpr; -use crate::logical_plan::optimizer; use crate::prelude::optimizer::predicate_pushdown::group_by::process_group_by; use crate::prelude::optimizer::predicate_pushdown::join::process_join; use crate::prelude::optimizer::predicate_pushdown::rename::process_rename; @@ -404,24 +403,15 @@ impl<'a> PredicatePushDown<'a> { input, options } => { - - if matches!(options.keep_strategy, UniqueKeepStrategy::Any | UniqueKeepStrategy::None) { - // currently the distinct operation only keeps the first occurrences. - // this may have influence on the pushed down predicates. If the pushed down predicates - // contain a binary expression (thus depending on values in multiple columns) - // the final result may differ if it is pushed down. - - let mut root_count = 0; - - // if this condition is called more than once, its a binary or ternary operation. - let condition = |_| { - if root_count == 0 { - root_count += 1; - false - } else { - true - } + if let Some(ref subset) = options.subset { + // Predicates on the subset can pass. + let subset = subset.clone(); + let mut names_set = PlHashSet::<&str>::with_capacity(subset.len()); + for name in subset.iter() { + names_set.insert(name.as_str()); }; + + let condition = |name: Arc| !names_set.contains(name.as_ref()); let local_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index 3e734f115d76e..49885b8e0b61a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -1,10 +1,7 @@ -use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; use super::keys::*; -use crate::logical_plan::Context; use crate::prelude::*; -use crate::utils::{aexpr_to_leaf_names, has_aexpr}; trait Dsl { fn and(self, right: Node, arena: &mut Arena) -> Node; @@ -115,8 +112,9 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) has_aexpr(node, expr_arena, matches) } -/// Transfer a predicate from `acc_predicates` that will be pushed down -/// to a local_predicates vec based on a condition. +/// Evaluates a condition on the column name inputs of every predicate, where if +/// the condition evaluates to true on any column name the predicate is +/// transferred to local. pub(super) fn transfer_to_local_by_name( expr_arena: &Arena, acc_predicates: &mut PlHashMap, Node>, @@ -132,7 +130,7 @@ where for name in root_names { if condition(name) { remove_keys.push(key.clone()); - continue; + break; } } } diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 68e032ebd170b..e10a13eb261fe 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -16,7 +16,6 @@ use polars_io::RowIndex; use semi_anti_join::process_semi_anti_join; use crate::logical_plan::Context; -use crate::prelude::iterator::ArenaExprIter; use crate::prelude::optimizer::projection_pushdown::generic::process_generic; use crate::prelude::optimizer::projection_pushdown::group_by::process_group_by; use crate::prelude::optimizer::projection_pushdown::hconcat::process_hconcat; @@ -26,8 +25,7 @@ use crate::prelude::optimizer::projection_pushdown::projection::process_projecti use crate::prelude::optimizer::projection_pushdown::rename::process_rename; use crate::prelude::*; use crate::utils::{ - aexpr_assign_renamed_leaf, aexpr_to_column_nodes, aexpr_to_leaf_names, check_input_node, - expr_is_projected_upstream, + aexpr_assign_renamed_leaf, aexpr_to_leaf_names, check_input_node, expr_is_projected_upstream, }; fn init_vec() -> Vec { @@ -151,11 +149,15 @@ fn update_scan_schema( Ok(new_schema) } -pub struct ProjectionPushDown {} +pub struct ProjectionPushDown { + pub is_count_star: bool, +} impl ProjectionPushDown { pub(super) fn new() -> Self { - Self {} + Self { + is_count_star: false, + } } /// Projection will be done at this node, but we continue optimization diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs index 20edf7dbe13ac..04fb96b11c1b8 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/projection.rs @@ -84,6 +84,7 @@ pub(super) fn process_projection( } add_expr_to_accumulated(expr, &mut acc_projections, &mut projected_names, expr_arena); local_projection.push(exprs[0]); + proj_pd.is_count_star = true; } else { // A projection can consist of a chain of expressions followed by an alias. // We want to do the chain locally because it can have complicated side effects. diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index b3d3440adc5e5..c2dd29a8b4910 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -1,67 +1,64 @@ -use polars_utils::arena::Arena; +use polars_utils::floor_divmod::FloorDivMod; +use polars_utils::total_ord::ToTotalOrd; -#[cfg(all(feature = "strings", feature = "concat_str"))] -use crate::dsl::function_expr::StringFunction; -use crate::logical_plan::optimizer::stack_opt::OptimizationRule; use crate::logical_plan::*; use crate::prelude::optimizer::simplify_functions::optimize_functions; macro_rules! eval_binary_same_type { - ($lhs:expr, $operand: tt, $rhs:expr) => {{ - if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { - match (lit_left, lit_right) { - (LiteralValue::Float32(x), LiteralValue::Float32(y)) => { - Some(AExpr::Literal(LiteralValue::Float32(x $operand y))) - } - (LiteralValue::Float64(x), LiteralValue::Float64(y)) => { - Some(AExpr::Literal(LiteralValue::Float64(x $operand y))) - } - #[cfg(feature = "dtype-i8")] - (LiteralValue::Int8(x), LiteralValue::Int8(y)) => { - Some(AExpr::Literal(LiteralValue::Int8(x $operand y))) - } - #[cfg(feature = "dtype-i16")] - (LiteralValue::Int16(x), LiteralValue::Int16(y)) => { - Some(AExpr::Literal(LiteralValue::Int16(x $operand y))) - } - (LiteralValue::Int32(x), LiteralValue::Int32(y)) => { - Some(AExpr::Literal(LiteralValue::Int32(x $operand y))) - } - (LiteralValue::Int64(x), LiteralValue::Int64(y)) => { - Some(AExpr::Literal(LiteralValue::Int64(x $operand y))) - } - #[cfg(feature = "dtype-u8")] - (LiteralValue::UInt8(x), LiteralValue::UInt8(y)) => { - Some(AExpr::Literal(LiteralValue::UInt8(x $operand y))) - } - #[cfg(feature = "dtype-u16")] - (LiteralValue::UInt16(x), LiteralValue::UInt16(y)) => { - Some(AExpr::Literal(LiteralValue::UInt16(x $operand y))) - } - (LiteralValue::UInt32(x), LiteralValue::UInt32(y)) => { - Some(AExpr::Literal(LiteralValue::UInt32(x $operand y))) - } - (LiteralValue::UInt64(x), LiteralValue::UInt64(y)) => { - Some(AExpr::Literal(LiteralValue::UInt64(x $operand y))) + ($lhs:expr, $rhs:expr, |$l: ident, $r: ident| $ret: expr) => {{ + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { + match (lit_left, lit_right) { + (LiteralValue::Float32($l), LiteralValue::Float32($r)) => { + Some(AExpr::Literal(LiteralValue::Float32($ret))) + }, + (LiteralValue::Float64($l), LiteralValue::Float64($r)) => { + Some(AExpr::Literal(LiteralValue::Float64($ret))) + }, + #[cfg(feature = "dtype-i8")] + (LiteralValue::Int8($l), LiteralValue::Int8($r)) => { + Some(AExpr::Literal(LiteralValue::Int8($ret))) + }, + #[cfg(feature = "dtype-i16")] + (LiteralValue::Int16($l), LiteralValue::Int16($r)) => { + Some(AExpr::Literal(LiteralValue::Int16($ret))) + }, + (LiteralValue::Int32($l), LiteralValue::Int32($r)) => { + Some(AExpr::Literal(LiteralValue::Int32($ret))) + }, + (LiteralValue::Int64($l), LiteralValue::Int64($r)) => { + Some(AExpr::Literal(LiteralValue::Int64($ret))) + }, + #[cfg(feature = "dtype-u8")] + (LiteralValue::UInt8($l), LiteralValue::UInt8($r)) => { + Some(AExpr::Literal(LiteralValue::UInt8($ret))) + }, + #[cfg(feature = "dtype-u16")] + (LiteralValue::UInt16($l), LiteralValue::UInt16($r)) => { + Some(AExpr::Literal(LiteralValue::UInt16($ret))) + }, + (LiteralValue::UInt32($l), LiteralValue::UInt32($r)) => { + Some(AExpr::Literal(LiteralValue::UInt32($ret))) + }, + (LiteralValue::UInt64($l), LiteralValue::UInt64($r)) => { + Some(AExpr::Literal(LiteralValue::UInt64($ret))) + }, + _ => None, } - _ => None, + } else { + None } - } else { - None - } - - }} + }}; } -macro_rules! eval_binary_bool_type { +macro_rules! eval_binary_cmp_same_type { ($lhs:expr, $operand: tt, $rhs:expr) => {{ if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = ($lhs, $rhs) { match (lit_left, lit_right) { (LiteralValue::Float32(x), LiteralValue::Float32(y)) => { - Some(AExpr::Literal(LiteralValue::Boolean(x $operand y))) + Some(AExpr::Literal(LiteralValue::Boolean(x.to_total_ord() $operand y.to_total_ord()))) } (LiteralValue::Float64(x), LiteralValue::Float64(y)) => { - Some(AExpr::Literal(LiteralValue::Boolean(x $operand y))) + Some(AExpr::Literal(LiteralValue::Boolean(x.to_total_ord() $operand y.to_total_ord()))) } #[cfg(feature = "dtype-i8")] (LiteralValue::Int8(x), LiteralValue::Int8(y)) => { @@ -303,12 +300,8 @@ fn string_addition_to_linear_concat( let schema = lp_arena.get(input).schema(lp_arena); let get_type = |ae: &AExpr| ae.get_type(&schema, Context::Default, expr_arena).ok(); - let type_a = get_type(left_aexpr) - .or_else(|| get_type(right_aexpr)) - .unwrap(); - let type_b = get_type(right_aexpr) - .or_else(|| get_type(right_aexpr)) - .unwrap(); + let type_a = get_type(left_aexpr).or_else(|| get_type(right_aexpr))?; + let type_b = get_type(right_aexpr).or_else(|| get_type(right_aexpr))?; if type_a != type_b { return None; @@ -447,7 +440,7 @@ impl OptimizationRule for SimplifyExprRule { #[allow(clippy::manual_map)] let out = match op { Plus => { - match eval_binary_same_type!(left_aexpr, +, right_aexpr) { + match eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + r) { Some(new) => Some(new), None => { // try to replace addition of string columns with `concat_str` @@ -470,9 +463,61 @@ impl OptimizationRule for SimplifyExprRule { }, } }, - Minus => eval_binary_same_type!(left_aexpr, -, right_aexpr), - Multiply => eval_binary_same_type!(left_aexpr, *, right_aexpr), - Divide => eval_binary_same_type!(left_aexpr, /, right_aexpr), + Minus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l - r), + Multiply => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l * r), + Divide => { + if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = + (left_aexpr, right_aexpr) + { + match (lit_left, lit_right) { + (LiteralValue::Float32(x), LiteralValue::Float32(y)) => { + Some(AExpr::Literal(LiteralValue::Float32(x / y))) + }, + (LiteralValue::Float64(x), LiteralValue::Float64(y)) => { + Some(AExpr::Literal(LiteralValue::Float64(x / y))) + }, + #[cfg(feature = "dtype-i8")] + (LiteralValue::Int8(x), LiteralValue::Int8(y)) => { + Some(AExpr::Literal(LiteralValue::Int8( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + #[cfg(feature = "dtype-i16")] + (LiteralValue::Int16(x), LiteralValue::Int16(y)) => { + Some(AExpr::Literal(LiteralValue::Int16( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + (LiteralValue::Int32(x), LiteralValue::Int32(y)) => { + Some(AExpr::Literal(LiteralValue::Int32( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + (LiteralValue::Int64(x), LiteralValue::Int64(y)) => { + Some(AExpr::Literal(LiteralValue::Int64( + x.wrapping_floor_div_mod(*y).0, + ))) + }, + #[cfg(feature = "dtype-u8")] + (LiteralValue::UInt8(x), LiteralValue::UInt8(y)) => { + Some(AExpr::Literal(LiteralValue::UInt8(x / y))) + }, + #[cfg(feature = "dtype-u16")] + (LiteralValue::UInt16(x), LiteralValue::UInt16(y)) => { + Some(AExpr::Literal(LiteralValue::UInt16(x / y))) + }, + (LiteralValue::UInt32(x), LiteralValue::UInt32(y)) => { + Some(AExpr::Literal(LiteralValue::UInt32(x / y))) + }, + (LiteralValue::UInt64(x), LiteralValue::UInt64(y)) => { + Some(AExpr::Literal(LiteralValue::UInt64(x / y))) + }, + _ => None, + } + } else { + None + } + }, TrueDivide => { if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) = (left_aexpr, right_aexpr) @@ -518,17 +563,23 @@ impl OptimizationRule for SimplifyExprRule { None } }, - Modulus => eval_binary_same_type!(left_aexpr, %, right_aexpr), - Lt => eval_binary_bool_type!(left_aexpr, <, right_aexpr), - Gt => eval_binary_bool_type!(left_aexpr, >, right_aexpr), - Eq | EqValidity => eval_binary_bool_type!(left_aexpr, ==, right_aexpr), - NotEq | NotEqValidity => eval_binary_bool_type!(left_aexpr, !=, right_aexpr), - GtEq => eval_binary_bool_type!(left_aexpr, >=, right_aexpr), - LtEq => eval_binary_bool_type!(left_aexpr, <=, right_aexpr), + Modulus => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + .wrapping_floor_div_mod(*r) + .1), + Lt => eval_binary_cmp_same_type!(left_aexpr, <, right_aexpr), + Gt => eval_binary_cmp_same_type!(left_aexpr, >, right_aexpr), + Eq | EqValidity => eval_binary_cmp_same_type!(left_aexpr, ==, right_aexpr), + NotEq | NotEqValidity => { + eval_binary_cmp_same_type!(left_aexpr, !=, right_aexpr) + }, + GtEq => eval_binary_cmp_same_type!(left_aexpr, >=, right_aexpr), + LtEq => eval_binary_cmp_same_type!(left_aexpr, <=, right_aexpr), And | LogicalAnd => eval_bitwise(left_aexpr, right_aexpr, |l, r| l & r), Or | LogicalOr => eval_bitwise(left_aexpr, right_aexpr, |l, r| l | r), Xor => eval_bitwise(left_aexpr, right_aexpr, |l, r| l ^ r), - FloorDivide => None, + FloorDivide => eval_binary_same_type!(left_aexpr, right_aexpr, |l, r| l + .wrapping_floor_div_mod(*r) + .0), }; if out.is_some() { return Ok(out); diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs index a3d91acd2187f..70d7d57e413ab 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs @@ -1,6 +1,3 @@ -#[cfg(feature = "is_between")] -use polars_ops::series::ClosedInterval; - use super::*; pub(super) fn optimize_functions( diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 74d55baeb0681..3560ec2df01e4 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -1,5 +1,6 @@ use polars_core::prelude::*; +use crate::logical_plan::projection_expr::ProjectionExprs; use crate::prelude::*; pub(super) struct SlicePushDown { @@ -13,6 +14,49 @@ struct State { len: IdxSize, } +/// Can push down slice when: +/// * all projections are elementwise +/// * at least 1 projection is based on a column (for height broadcast) +/// * projections not based on any column project as scalars +/// +/// Returns (all_elementwise, all_elementwise_and_any_expr_has_column) +fn can_pushdown_slice_past_projections( + exprs: &ProjectionExprs, + arena: &Arena, +) -> (bool, bool) { + let mut all_elementwise_and_any_expr_has_column = false; + for node in exprs.iter() { + // `select(c = Literal([1, 2, 3])).slice(0, 0)` must block slice pushdown, + // because `c` projects to a height independent from the input height. We check + // this by observing that `c` does not have any columns in its input notes. + // + // TODO: Simply checking that a column node is present does not handle e.g.: + // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, + // `str.contains`, `str.contains_many` etc. - observe a column node is present + // but the output height is not dependent on it. + let mut has_column = false; + let mut literals_all_scalar = true; + let is_elementwise = arena.iter(*node).all(|(_node, ae)| { + has_column |= matches!(ae, AExpr::Column(_)); + literals_all_scalar &= if let AExpr::Literal(v) = ae { + v.projects_as_scalar() + } else { + true + }; + single_aexpr_is_elementwise(ae) + }); + + // If there is no column then all literals must be scalar + if !is_elementwise || !(has_column || literals_all_scalar) { + return (false, false); + } + + all_elementwise_and_any_expr_has_column |= has_column + } + + (true, all_elementwise_and_any_expr_has_column) +} + impl SlicePushDown { pub(super) fn new(streaming: bool) -> Self { Self { @@ -322,10 +366,7 @@ impl SlicePushDown { } // there is state, inspect the projection to determine how to deal with it (Projection {input, expr, schema, options}, Some(_)) => { - // The slice operation may only pass on simple projections. col("foo").alias("bar") - if expr.iter().all(|root| { - aexpr_is_elementwise(*root, expr_arena) - }) { + if can_pushdown_slice_past_projections(&expr, expr_arena).1 { let lp = Projection {input, expr, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } @@ -335,12 +376,16 @@ impl SlicePushDown { self.no_pushdown_restart_opt(lp, state, lp_arena, expr_arena) } } - // this is copied from `Projection` (HStack {input, exprs, schema, options}, _) => { - // The slice operation may only pass on simple projections. col("foo").alias("bar") - if exprs.iter().all(|root| { - aexpr_is_elementwise(*root, expr_arena) - }) { + let check = can_pushdown_slice_past_projections(&exprs, expr_arena); + + if ( + // If the schema length is greater then an input column is being projected, so + // the exprs in with_columns do not need to have an input column name. + schema.len() > exprs.len() && check.0 + ) + || check.1 // e.g. select(c).with_columns(c = c + 1) + { let lp = HStack {input, exprs, schema, options}; self.pushdown_and_continue(lp, state, lp_arena, expr_arena) } diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index 003e88672879c..6d0337ade15cc 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -8,10 +8,7 @@ use polars_utils::idx_vec::UnitVec; use polars_utils::unitvec; use super::*; -use crate::dsl::function_expr::FunctionExpr; use crate::logical_plan::optimizer::type_coercion::binary::process_binary; -use crate::logical_plan::Context; -use crate::utils::is_scan; pub struct TypeCoercionRule {} @@ -316,7 +313,7 @@ impl OptimizationRule for TypeCoercionRule { expr_arena.add(AExpr::Cast { expr: truthy_node, data_type: st.clone(), - strict: false, + strict: true, }) } else { truthy_node @@ -326,7 +323,7 @@ impl OptimizationRule for TypeCoercionRule { expr_arena.add(AExpr::Cast { expr: falsy_node, data_type: st, - strict: false, + strict: true, }) } else { falsy_node @@ -372,6 +369,10 @@ impl OptimizationRule for TypeCoercionRule { (DataType::Categorical(_, _) | DataType::Enum(_, _), DataType::String) => { return Ok(None) }, + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, _) | DataType::Enum(_, _)) => { + return Ok(None) + }, #[cfg(feature = "dtype-decimal")] (DataType::Decimal(_, _), _) | (_, DataType::Decimal(_, _)) => { polars_bail!(InvalidOperation: "`is_in` cannot check for {:?} values in {:?} data", &type_other, &type_left) diff --git a/crates/polars-plan/src/logical_plan/projection.rs b/crates/polars-plan/src/logical_plan/projection.rs index 9c8968b7a2f75..5d7924ed027f4 100644 --- a/crates/polars-plan/src/logical_plan/projection.rs +++ b/crates/polars-plan/src/logical_plan/projection.rs @@ -1,10 +1,7 @@ //! this contains code used for rewriting projections, expanding wildcards, regex selection etc. -use arrow::legacy::index::IndexToUsize; use polars_core::utils::get_supertype; use super::*; -use crate::prelude::function_expr::FunctionExpr; -use crate::utils::expr_output_name; /// This replace the wildcard Expr with a Column Expr. It also removes the Exclude Expr from the /// expression chain. diff --git a/crates/polars-plan/src/logical_plan/pyarrow.rs b/crates/polars-plan/src/logical_plan/pyarrow.rs index 91ba9aca3ea9b..82d83541de63c 100644 --- a/crates/polars-plan/src/logical_plan/pyarrow.rs +++ b/crates/polars-plan/src/logical_plan/pyarrow.rs @@ -13,11 +13,11 @@ pub(super) struct Args { } fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String { - // note: `_to_python_datetime` and the `Datetime` + // note: `to_py_datetime` and the `Datetime` // dtype have to be in-scope on the python side match tz { - None => format!("_to_python_datetime({},'{}')", v, tu.to_ascii()), - Some(tz) => format!("_to_python_datetime({},'{}',{})", v, tu.to_ascii(), tz), + None => format!("to_py_datetime({},'{}')", v, tu.to_ascii()), + Some(tz) => format!("to_py_datetime({},'{}',{})", v, tu.to_ascii(), tz), } } @@ -53,7 +53,7 @@ pub(super) fn predicate_to_pa( let dtm = to_py_datetime(v, &tu, tz.as_ref()); write!(list_repr, "{dtm},").unwrap(); } else if let AnyValue::Date(v) = av { - write!(list_repr, "_to_python_date({v}),").unwrap(); + write!(list_repr, "to_py_date({v}),").unwrap(); } else { write!(list_repr, "{av},").unwrap(); } @@ -79,25 +79,25 @@ pub(super) fn predicate_to_pa( }, #[cfg(feature = "dtype-date")] AnyValue::Date(v) => { - // the function `_to_python_date` and the `Date` + // the function `to_py_date` and the `Date` // dtype have to be in scope on the python side - Some(format!("_to_python_date({v})")) + Some(format!("to_py_date({v})")) }, #[cfg(feature = "dtype-datetime")] AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz.as_ref())), // Activate once pyarrow supports them // #[cfg(feature = "dtype-time")] // AnyValue::Time(v) => { - // // the function `_to_python_time` has to be in scope + // // the function `to_py_time` has to be in scope // // on the python side - // Some(format!("_to_python_time(value={v})")) + // Some(format!("to_py_time(value={v})")) // } // #[cfg(feature = "dtype-duration")] // AnyValue::Duration(v, tu) => { - // // the function `_to_python_timedelta` has to be in scope + // // the function `to_py_timedelta` has to be in scope // // on the python side // Some(format!( - // "_to_python_timedelta(value={}, tu='{}')", + // "to_py_timedelta(value={}, tu='{}')", // v, // tu.to_ascii() // )) diff --git a/crates/polars-plan/src/logical_plan/tree_format.rs b/crates/polars-plan/src/logical_plan/tree_format.rs index ebf045c1383e7..a4d8f8eb2f02f 100644 --- a/crates/polars-plan/src/logical_plan/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/tree_format.rs @@ -260,7 +260,13 @@ impl<'a> TreeFmtNode<'a> { .collect(), ), NL(h, Distinct { input, options }) => ND( - wh(h, &format!("UNIQUE BY {:?}", options.subset)), + wh( + h, + &format!( + "UNIQUE[maintain_order: {:?}, keep_strategy: {:?}] BY {:?}", + options.maintain_order, options.keep_strategy, options.subset + ), + ), vec![NL(None, input)], ), NL(h, LogicalPlan::Slice { input, offset, len }) => ND( diff --git a/crates/polars-plan/src/prelude.rs b/crates/polars-plan/src/prelude.rs index 85da66d68b616..0b3f37cdfb222 100644 --- a/crates/polars-plan/src/prelude.rs +++ b/crates/polars-plan/src/prelude.rs @@ -9,11 +9,6 @@ pub(crate) use polars_time::in_nanoseconds_window; feature = "dtype-time" ))] pub(crate) use polars_time::prelude::*; -#[cfg(feature = "rolling_window")] -pub(crate) use polars_time::{ - chunkedarray::{RollingOptions, RollingOptionsImpl}, - Duration, -}; pub use polars_utils::arena::{Arena, Node}; pub use crate::dsl::*; diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index ff1118054a06a..4f88265a7dca7 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -1,13 +1,10 @@ use std::fmt::Formatter; use std::iter::FlatMap; -use std::sync::Arc; use polars_core::prelude::*; use polars_utils::idx_vec::UnitVec; use smartstring::alias::String as SmartString; -use crate::logical_plan::iterator::ArenaExprIter; -use crate::logical_plan::Context; use crate::prelude::consts::{LEN, LITERAL_NAME}; use crate::prelude::*; @@ -82,22 +79,17 @@ pub(crate) fn aexpr_is_simple_projection(current_node: Node, arena: &Arena) -> bool { - arena.iter(current_node).all(|(_node, e)| { - use AExpr::*; - match e { - AnonymousFunction { options, .. } | Function { options, .. } => { - !matches!(options.collect_groups, ApplyOptions::GroupWise) - }, - Column(_) - | Alias(_, _) - | Literal(_) - | BinaryExpr { .. } - | Ternary { .. } - | Cast { .. } => true, - _ => false, - } - }) +pub(crate) fn single_aexpr_is_elementwise(ae: &AExpr) -> bool { + use AExpr::*; + match ae { + AnonymousFunction { options, .. } | Function { options, .. } => { + !matches!(options.collect_groups, ApplyOptions::GroupWise) + }, + Column(_) | Alias(_, _) | Literal(_) | BinaryExpr { .. } | Ternary { .. } | Cast { .. } => { + true + }, + _ => false, + } } pub fn has_aexpr(current_node: Node, arena: &Arena, matches: F) -> bool @@ -159,6 +151,30 @@ pub fn has_null(current_expr: &Expr) -> bool { }) } +pub fn aexpr_output_name(node: Node, arena: &Arena) -> PolarsResult> { + for (_, ae) in arena.iter(node) { + match ae { + // don't follow the partition by branch + AExpr::Window { function, .. } => return aexpr_output_name(*function, arena), + AExpr::Column(name) => return Ok(name.clone()), + AExpr::Alias(_, name) => return Ok(name.clone()), + AExpr::Len => return Ok(Arc::from(LEN)), + AExpr::Literal(val) => { + return match val { + LiteralValue::Series(s) => Ok(Arc::from(s.name())), + _ => Ok(Arc::from(LITERAL_NAME)), + } + }, + _ => {}, + } + } + let expr = node_to_expr(node, arena); + polars_bail!( + ComputeError: + "unable to find root column name for expr '{expr:?}' when calling 'output_name'", + ); +} + /// output name of expr pub fn expr_output_name(expr: &Expr) -> PolarsResult> { for e in expr { @@ -292,9 +308,9 @@ pub(crate) fn aexpr_assign_renamed_leaf( current: &str, new_name: &str, ) -> Node { - let leafs = aexpr_to_column_nodes_iter(node, arena); + let leaf_nodes = aexpr_to_column_nodes_iter(node, arena); - for node in leafs { + for node in leaf_nodes { match arena.get(node) { AExpr::Column(name) if &**name == current => { return arena.add(AExpr::Column(Arc::from(new_name))) diff --git a/crates/polars-row/Cargo.toml b/crates/polars-row/Cargo.toml index cd764898b2d6c..1ca2138c87263 100644 --- a/crates/polars-row/Cargo.toml +++ b/crates/polars-row/Cargo.toml @@ -9,6 +9,7 @@ repository = { workspace = true } description = "Row encodings for the Polars DataFrame library" [dependencies] +bytemuck = { workspace = true } polars-error = { workspace = true } polars-utils = { workspace = true } diff --git a/crates/polars-row/src/encode.rs b/crates/polars-row/src/encode.rs index 8ed9a95cb5e86..bd94f33a203c9 100644 --- a/crates/polars-row/src/encode.rs +++ b/crates/polars-row/src/encode.rs @@ -2,8 +2,11 @@ use arrow::array::{ Array, BinaryArray, BinaryViewArray, BooleanArray, DictionaryArray, PrimitiveArray, StructArray, Utf8ViewArray, }; +use arrow::bitmap::utils::ZipValidity; use arrow::datatypes::ArrowDataType; +use arrow::legacy::prelude::{LargeBinaryArray, LargeListArray}; use arrow::types::NativeType; +use polars_utils::slice::GetSaferUnchecked; use polars_utils::vec::PushUnchecked; use crate::fixed::FixedLengthEncoding; @@ -30,6 +33,119 @@ pub fn convert_columns_amortized_no_order(columns: &[ArrayRef], rows: &mut RowsE ); } +enum Encoder { + // For list encoding we recursively call encode on the inner until we + // have a leaf we can encode. + // On allocation we already encode the leaves and set those to `rows`. + List { + enc: Vec, + rows: Option, + original: LargeListArray, + field: SortField, + }, + Leaf(ArrayRef), +} + +impl Encoder { + fn list_iter(&self) -> impl Iterator> { + match self { + Encoder::Leaf(_) => unreachable!(), + Encoder::List { original, rows, .. } => { + let rows = rows.as_ref().unwrap(); + // This should be 0 due to rows encoding; + assert_eq!(rows.null_count(), 0); + + let offsets = original.offsets().windows(2); + let zipped = ZipValidity::new_with_validity(offsets, original.validity()); + + let binary_offsets = rows.offsets(); + let row_values = rows.values().as_slice(); + + zipped.map(|opt_window| { + opt_window.map(|window| { + unsafe { + // Offsets of the list + let start = *window.get_unchecked_release(0); + let end = *window.get_unchecked_release(1); + + // Offsets in the binary values. + let start = *binary_offsets.get_unchecked_release(start as usize); + let end = *binary_offsets.get_unchecked_release(end as usize); + + let start = start as usize; + let end = end as usize; + + row_values.get_unchecked_release(start..end) + } + }) + }) + }, + } + } + + fn len(&self) -> usize { + match self { + Encoder::List { original, .. } => original.len(), + Encoder::Leaf(arr) => arr.len(), + } + } + + fn data_type(&self) -> &ArrowDataType { + match self { + Encoder::List { original, .. } => original.data_type(), + Encoder::Leaf(arr) => arr.data_type(), + } + } + + fn is_variable(&self) -> bool { + match self { + Encoder::Leaf(arr) => { + matches!( + arr.data_type(), + ArrowDataType::BinaryView + | ArrowDataType::Dictionary(_, _, _) + | ArrowDataType::LargeBinary + ) + }, + Encoder::List { .. } => true, + } + } +} + +fn get_encoders(arr: &dyn Array, encoders: &mut Vec, field: &SortField) -> usize { + let mut added = 0; + match arr.data_type() { + ArrowDataType::Struct(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + for arr in arr.values() { + added += get_encoders(arr.as_ref(), encoders, field); + } + }, + ArrowDataType::Utf8View => { + let arr = arr.as_any().downcast_ref::().unwrap(); + encoders.push(Encoder::Leaf(arr.to_binview().boxed())); + added += 1 + }, + ArrowDataType::LargeList(_) => { + let arr = arr.as_any().downcast_ref::().unwrap(); + let mut inner = vec![]; + get_encoders(arr.values().as_ref(), &mut inner, field); + encoders.push(Encoder::List { + enc: inner, + original: arr.clone(), + rows: None, + field: field.clone(), + }); + added += 1; + }, + _ => { + encoders.push(Encoder::Leaf(arr.to_boxed())); + added += 1; + }, + } + added +} + pub fn convert_columns_amortized<'a, I: IntoIterator>( columns: &'a [ArrayRef], fields: I, @@ -40,47 +156,37 @@ pub fn convert_columns_amortized<'a, I: IntoIterator>( if columns.iter().any(|arr| { matches!( arr.data_type(), - ArrowDataType::Struct(_) | ArrowDataType::Utf8View + ArrowDataType::Struct(_) | ArrowDataType::Utf8View | ArrowDataType::LargeList(_) ) }) { let mut flattened_columns = Vec::with_capacity(columns.len() * 5); let mut flattened_fields = Vec::with_capacity(columns.len() * 5); for (arr, field) in columns.iter().zip(fields) { - match arr.data_type() { - ArrowDataType::Struct(_) => { - let arr = arr.as_any().downcast_ref::().unwrap(); - for arr in arr.values() { - flattened_columns.push(arr.clone() as ArrayRef); - flattened_fields.push(field.clone()) - } - }, - ArrowDataType::Utf8View => { - let arr = arr.as_any().downcast_ref::().unwrap(); - flattened_columns.push(arr.to_binview().boxed()); - flattened_fields.push(field.clone()); - }, - _ => { - flattened_columns.push(arr.clone()); - flattened_fields.push(field.clone()); - }, + let added = get_encoders(arr.as_ref(), &mut flattened_columns, field); + for _ in 0..added { + flattened_fields.push(field.clone()); } } let values_size = - allocate_rows_buf(&flattened_columns, &mut rows.values, &mut rows.offsets); + allocate_rows_buf(&mut flattened_columns, &mut rows.values, &mut rows.offsets); for (arr, field) in flattened_columns.iter().zip(flattened_fields.iter()) { // SAFETY: // we allocated rows with enough bytes. - unsafe { encode_array(&**arr, field, rows) } + unsafe { encode_array(arr, field, rows) } } // SAFETY: values are initialized unsafe { rows.values.set_len(values_size) } } else { - let values_size = allocate_rows_buf(columns, &mut rows.values, &mut rows.offsets); - for (arr, field) in columns.iter().zip(fields) { + let mut encoders = columns + .iter() + .map(|arr| Encoder::Leaf(arr.clone())) + .collect::>(); + let values_size = allocate_rows_buf(&mut encoders, &mut rows.values, &mut rows.offsets); + for (enc, field) in encoders.iter().zip(fields) { // SAFETY: // we allocated rows with enough bytes. - unsafe { encode_array(&**arr, field, rows) } + unsafe { encode_array(enc, field, rows) } } // SAFETY: values are initialized unsafe { rows.values.set_len(values_size) } @@ -105,41 +211,49 @@ fn encode_primitive( /// /// # Safety /// `out` must have enough bytes allocated otherwise it will be out of bounds. -unsafe fn encode_array(array: &dyn Array, field: &SortField, out: &mut RowsEncoded) { - match array.data_type() { - ArrowDataType::Boolean => { - let array = array.as_any().downcast_ref::().unwrap(); - crate::fixed::encode_iter(array.into_iter(), out, field); +unsafe fn encode_array(encoder: &Encoder, field: &SortField, out: &mut RowsEncoded) { + match encoder { + Encoder::List { .. } => { + let iter = encoder.list_iter(); + crate::variable::encode_iter(iter, out, &Default::default()) }, - ArrowDataType::LargeBinary => { - let array = array.as_any().downcast_ref::>().unwrap(); - crate::variable::encode_iter(array.into_iter(), out, field) - }, - ArrowDataType::BinaryView => { - let array = array.as_any().downcast_ref::().unwrap(); - crate::variable::encode_iter(array.into_iter(), out, field) - }, - ArrowDataType::LargeUtf8 | ArrowDataType::Utf8View => { - panic!("should be cast to binary") - }, - ArrowDataType::Dictionary(_, _, _) => { - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let iter = array - .iter_typed::() - .unwrap() - .map(|opt_s| opt_s.map(|s| s.as_bytes())); - crate::variable::encode_iter(iter, out, field) - }, - dt => { - with_match_arrow_primitive_type!(dt, |$T| { - let array = array.as_any().downcast_ref::>().unwrap(); - encode_primitive(array, field, out); - }) + Encoder::Leaf(array) => { + match array.data_type() { + ArrowDataType::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + crate::fixed::encode_iter(array.into_iter(), out, field); + }, + ArrowDataType::LargeBinary => { + let array = array.as_any().downcast_ref::>().unwrap(); + crate::variable::encode_iter(array.into_iter(), out, field) + }, + ArrowDataType::BinaryView => { + let array = array.as_any().downcast_ref::().unwrap(); + crate::variable::encode_iter(array.into_iter(), out, field) + }, + ArrowDataType::Utf8View => { + panic!("should be binview") + }, + ArrowDataType::Dictionary(_, _, _) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let iter = array + .iter_typed::() + .unwrap() + .map(|opt_s| opt_s.map(|s| s.as_bytes())); + crate::variable::encode_iter(iter, out, field) + }, + dt => { + with_match_arrow_primitive_type!(dt, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + encode_primitive(array, field, out); + }) + }, + }; }, - }; + } } pub fn encoded_size(data_type: &ArrowDataType) -> usize { @@ -153,6 +267,7 @@ pub fn encoded_size(data_type: &ArrowDataType) -> usize { Int16 => i16::ENCODED_LEN, Int32 => i32::ENCODED_LEN, Int64 => i64::ENCODED_LEN, + Decimal(_, _) => i128::ENCODED_LEN, Float32 => f32::ENCODED_LEN, Float64 => f64::ENCODED_LEN, Boolean => bool::ENCODED_LEN, @@ -162,19 +277,12 @@ pub fn encoded_size(data_type: &ArrowDataType) -> usize { // Returns the length that the caller must set on the `values` buf once the bytes // are initialized. -pub fn allocate_rows_buf( - columns: &[ArrayRef], +fn allocate_rows_buf( + columns: &mut [Encoder], values: &mut Vec, offsets: &mut Vec, ) -> usize { - let has_variable = columns.iter().any(|arr| { - matches!( - arr.data_type(), - ArrowDataType::BinaryView - | ArrowDataType::Dictionary(_, _, _) - | ArrowDataType::LargeBinary - ) - }); + let has_variable = columns.iter().any(|enc| enc.is_variable()); let num_rows = columns[0].len(); if has_variable { @@ -182,16 +290,11 @@ pub fn allocate_rows_buf( // those can be determined without looping over the arrays let row_size_fixed: usize = columns .iter() - .map(|arr| { - if matches!( - arr.data_type(), - ArrowDataType::BinaryView - | ArrowDataType::Dictionary(_, _, _) - | ArrowDataType::LargeBinary - ) { + .map(|enc| { + if enc.is_variable() { 0 } else { - encoded_size(arr.data_type()) + encoded_size(enc.data_type()) } }) .sum(); @@ -204,57 +307,53 @@ pub fn allocate_rows_buf( // for the variable length columns we must iterate to determine the length per row location let mut processed_count = 0; - for array in columns.iter() { - match array.data_type() { - ArrowDataType::BinaryView => { - let array = array.as_any().downcast_ref::().unwrap(); - if processed_count == 0 { - for opt_val in array.into_iter() { - unsafe { - lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val), - ); - } - } - } else { - for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val) - } + for enc in columns.iter_mut() { + match enc { + Encoder::List { + enc: inner_enc, + rows, + field, + original, + } => { + // Nested lists don't yet work as that requires the leaves not only allocating, but also + // encoding. To make that work we must add a flag `in_list` that tell the leaves to immediately + // encode the rows instead of only setting the length. + // This needs a bit refactoring, might require allocation and encoding to be in + // the same function. + if let ArrowDataType::LargeList(inner) = original.data_type() { + assert!( + !matches!(inner.data_type, ArrowDataType::LargeList(_)), + "should not be nested" + ) } - processed_count += 1; - }, - ArrowDataType::LargeBinary => { - let array = array.as_any().downcast_ref::>().unwrap(); - if processed_count == 0 { - for opt_val in array.into_iter() { - unsafe { - lengths.push_unchecked( - row_size_fixed + crate::variable::encoded_len(opt_val), - ); - } - } - } else { - for (opt_val, row_length) in array.into_iter().zip(lengths.iter_mut()) { - *row_length += crate::variable::encoded_len(opt_val) + // Create the row encoding for the inner type. + let mut values_rows = RowsEncoded::default(); + + // Allocate and immediately row-encode the inner types recursively. + let values_size = allocate_rows_buf( + inner_enc, + &mut values_rows.values, + &mut values_rows.offsets, + ); + + // For single nested it does work as we encode here. + unsafe { + for enc in inner_enc { + encode_array(enc, field, &mut values_rows) } - } - processed_count += 1; - }, - ArrowDataType::Dictionary(_, _, _) => { - let array = array - .as_any() - .downcast_ref::>() - .unwrap(); - let iter = array - .iter_typed::() - .unwrap() - .map(|opt_s| opt_s.map(|s| s.as_bytes())); + values_rows.values.set_len(values_size) + }; + let values_rows = values_rows.into_array(); + *rows = Some(values_rows); + + let iter = enc.list_iter(); + if processed_count == 0 { for opt_val in iter { unsafe { lengths.push_unchecked( row_size_fixed + crate::variable::encoded_len(opt_val), - ) + ); } } } else { @@ -264,8 +363,74 @@ pub fn allocate_rows_buf( } processed_count += 1; }, - _ => { - // the rest is fixed + Encoder::Leaf(array) => { + match array.data_type() { + ArrowDataType::BinaryView => { + let array = array.as_any().downcast_ref::().unwrap(); + if processed_count == 0 { + for opt_val in array.into_iter() { + unsafe { + lengths.push_unchecked( + row_size_fixed + crate::variable::encoded_len(opt_val), + ); + } + } + } else { + for (opt_val, row_length) in + array.into_iter().zip(lengths.iter_mut()) + { + *row_length += crate::variable::encoded_len(opt_val) + } + } + processed_count += 1; + }, + ArrowDataType::LargeBinary => { + let array = array.as_any().downcast_ref::>().unwrap(); + if processed_count == 0 { + for opt_val in array.into_iter() { + unsafe { + lengths.push_unchecked( + row_size_fixed + crate::variable::encoded_len(opt_val), + ); + } + } + } else { + for (opt_val, row_length) in + array.into_iter().zip(lengths.iter_mut()) + { + *row_length += crate::variable::encoded_len(opt_val) + } + } + processed_count += 1; + }, + ArrowDataType::Dictionary(_, _, _) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let iter = array + .iter_typed::() + .unwrap() + .map(|opt_s| opt_s.map(|s| s.as_bytes())); + if processed_count == 0 { + for opt_val in iter { + unsafe { + lengths.push_unchecked( + row_size_fixed + crate::variable::encoded_len(opt_val), + ) + } + } + } else { + for (opt_val, row_length) in iter.zip(lengths.iter_mut()) { + *row_length += crate::variable::encoded_len(opt_val) + } + } + processed_count += 1; + }, + _ => { + // the rest is fixed + }, + } }, } } @@ -324,6 +489,7 @@ pub fn allocate_rows_buf( #[cfg(test)] mod test { use arrow::array::Int32Array; + use arrow::offset::Offsets; use super::*; use crate::decode::decode_rows_from_binary; @@ -433,4 +599,31 @@ mod test { assert_eq!(decoded, &a); } } + + #[test] + fn test_list_encode() { + let values = Utf8ViewArray::from_slice_values([ + "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", + ]); + let dtype = LargeListArray::default_datatype(values.data_type().clone()); + let array = LargeListArray::new( + dtype, + Offsets::::try_from(vec![0i64, 1, 4, 7, 7, 9, 10]) + .unwrap() + .into(), + values.boxed(), + None, + ); + let fields = &[SortField { + descending: true, + nulls_last: false, + }]; + + let out = convert_columns(&[array.boxed()], fields); + let out = out.into_array(); + assert_eq!( + out.values().iter().map(|v| *v as usize).sum::(), + 84981 + ); + } } diff --git a/crates/polars-row/src/fixed.rs b/crates/polars-row/src/fixed.rs index 3b0560fc0a10b..dfc9d6ff94f6b 100644 --- a/crates/polars-row/src/fixed.rs +++ b/crates/polars-row/src/fixed.rs @@ -12,7 +12,6 @@ use crate::row::{RowsEncoded, SortField}; pub(crate) trait FromSlice { fn from_slice(slice: &[u8]) -> Self; - fn from_slice_inverted(slice: &[u8]) -> Self; } impl FromSlice for [u8; N] { @@ -20,10 +19,6 @@ impl FromSlice for [u8; N] { fn from_slice(slice: &[u8]) -> Self { slice.try_into().unwrap() } - - fn from_slice_inverted(_slice: &[u8]) -> Self { - todo!() - } } /// Encodes a value of a particular fixed width type into bytes @@ -110,6 +105,7 @@ encode_signed!(1, i8); encode_signed!(2, i16); encode_signed!(4, i32); encode_signed!(8, i64); +encode_signed!(16, i128); impl FixedLengthEncoding for f32 { type Encoded = [u8; 4]; diff --git a/crates/polars-row/src/row.rs b/crates/polars-row/src/row.rs index 6752d065228fe..26ba4715d33b2 100644 --- a/crates/polars-row/src/row.rs +++ b/crates/polars-row/src/row.rs @@ -34,7 +34,7 @@ unsafe fn rows_to_array(buf: Vec, offsets: Vec) -> BinaryArray { checks(&offsets); // SAFETY: we checked overflow - let offsets = std::mem::transmute::, Vec>(offsets); + let offsets = bytemuck::cast_vec::(offsets); // SAFETY: monotonically increasing let offsets = Offsets::new_unchecked(offsets); @@ -60,14 +60,14 @@ impl RowsEncoded { /// Borrows the buffers and returns a [`BinaryArray`]. /// /// # Safety - /// The lifetime of that `BinaryArray` is tight to the lifetime of + /// The lifetime of that `BinaryArray` is tied to the lifetime of /// `Self`. The caller must ensure that both stay alive for the same time. pub unsafe fn borrow_array(&self) -> BinaryArray { checks(&self.offsets); unsafe { let (_, values, _) = mmap::slice(&self.values).into_inner(); - let offsets = std::mem::transmute::<&[usize], &[i64]>(self.offsets.as_slice()); + let offsets = bytemuck::cast_slice::(self.offsets.as_slice()); let (_, offsets, _) = mmap::slice(offsets).into_inner(); let offsets = OffsetsBuffer::new_unchecked(offsets); diff --git a/crates/polars-row/src/utils.rs b/crates/polars-row/src/utils.rs index 9ad339bced832..2681527f07fa6 100644 --- a/crates/polars-row/src/utils.rs +++ b/crates/polars-row/src/utils.rs @@ -9,6 +9,7 @@ macro_rules! with_match_arrow_primitive_type {( Int16 => __with_ty__! { i16 }, Int32 => __with_ty__! { i32 }, Int64 => __with_ty__! { i64 }, + Decimal(_, _) => __with_ty__! { i128 }, UInt8 => __with_ty__! { u8 }, UInt16 => __with_ty__! { u16 }, UInt32 => __with_ty__! { u32 }, diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 8e67535e69429..4e84219c99634 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -5,7 +5,6 @@ use polars_core::prelude::*; use polars_error::to_compute_err; use polars_lazy::prelude::*; use polars_plan::prelude::*; -use polars_plan::utils::expressions_to_schema; use sqlparser::ast::{ Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator, diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 759e99c1385d2..4d5c74982b1de 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -3,10 +3,8 @@ use std::ops::Div; use polars_core::export::regex; use polars_core::prelude::*; use polars_error::to_compute_err; -use polars_lazy::dsl::Expr; use polars_lazy::prelude::*; use polars_plan::prelude::LiteralValue::Null; -use polars_plan::prelude::{col, lit, when}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; #[cfg(feature = "dtype-decimal")] diff --git a/crates/polars-sql/tests/ops_distinct_on.rs b/crates/polars-sql/tests/ops_distinct_on.rs index f502192d51cc8..d9016b24a9b05 100644 --- a/crates/polars-sql/tests/ops_distinct_on.rs +++ b/crates/polars-sql/tests/ops_distinct_on.rs @@ -1,5 +1,4 @@ use polars_core::df; -use polars_core::prelude::*; use polars_lazy::prelude::*; use polars_sql::*; diff --git a/crates/polars-sql/tests/udf.rs b/crates/polars-sql/tests/udf.rs index fd52b63eea338..18b55629fb07e 100644 --- a/crates/polars-sql/tests/udf.rs +++ b/crates/polars-sql/tests/udf.rs @@ -1,8 +1,4 @@ -use std::sync::Arc; - -use polars_core::prelude::{DataType, Field, *}; -use polars_core::series::Series; -use polars_error::PolarsResult; +use polars_core::prelude::*; use polars_lazy::prelude::IntoLazy; use polars_plan::prelude::{GetOutput, UserDefinedFunction}; use polars_sql::function_registry::FunctionRegistry; diff --git a/crates/polars-time/src/chunkedarray/datetime.rs b/crates/polars-time/src/chunkedarray/datetime.rs index 9ce60778eba29..f0111be9047fa 100644 --- a/crates/polars-time/src/chunkedarray/datetime.rs +++ b/crates/polars-time/src/chunkedarray/datetime.rs @@ -1,4 +1,3 @@ -use arrow; use arrow::array::{Array, PrimitiveArray}; use arrow::compute::cast::{cast, CastOptions}; use arrow::compute::temporal; @@ -161,8 +160,6 @@ impl DatetimeMethods for DatetimeChunked {} #[cfg(test)] mod test { - use chrono::NaiveDateTime; - use super::*; #[test] diff --git a/crates/polars-time/src/chunkedarray/kernels.rs b/crates/polars-time/src/chunkedarray/kernels.rs index 5e453925f8f63..8526c180ad1a1 100644 --- a/crates/polars-time/src/chunkedarray/kernels.rs +++ b/crates/polars-time/src/chunkedarray/kernels.rs @@ -1,6 +1,6 @@ //! macros that define kernels for extracting //! `week`, `weekday`, `year`, `hour` etc. from primitive arrays. -use arrow::array::{ArrayRef, BooleanArray, PrimitiveArray}; +use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::compute::arity::unary; #[cfg(feature = "dtype-time")] use arrow::temporal_conversions::time64ns_to_time_opt; @@ -8,7 +8,7 @@ use arrow::temporal_conversions::{ date32_to_datetime_opt, timestamp_ms_to_datetime_opt, timestamp_ns_to_datetime_opt, timestamp_us_to_datetime_opt, }; -use chrono::{Datelike, NaiveDate, NaiveDateTime, Timelike}; +use chrono::{Datelike, Timelike}; use super::super::windows::calendar::*; use super::*; diff --git a/crates/polars-time/src/chunkedarray/mod.rs b/crates/polars-time/src/chunkedarray/mod.rs index 029764fde2cdf..4c2fb9cbf5051 100644 --- a/crates/polars-time/src/chunkedarray/mod.rs +++ b/crates/polars-time/src/chunkedarray/mod.rs @@ -28,10 +28,6 @@ pub use string::StringMethods; #[cfg(feature = "dtype-time")] pub use time::TimeMethods; -pub fn unix_time() -> NaiveDateTime { - NaiveDateTime::from_timestamp_opt(0, 0).unwrap() -} - // a separate function so that it is not compiled twice #[cfg(any(feature = "dtype-date", feature = "dtype-datetime"))] pub(crate) fn months_to_quarters(mut ca: Int8Chunked) -> Int8Chunked { diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index 205233b15885a..5eb03d2bcc0fa 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -1,8 +1,6 @@ mod dispatch; mod rolling_kernels; -use std::convert::TryFrom; - use arrow::array::{Array, ArrayRef, PrimitiveArray}; use arrow::legacy::kernels::rolling; pub use dispatch::*; diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/mod.rs index 3641bb4f99313..1106a679e30f3 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/mod.rs @@ -3,7 +3,6 @@ use std::fmt::Debug; use std::ops::{AddAssign, Mul, SubAssign}; use arrow::array::{ArrayRef, PrimitiveArray}; -use arrow::legacy::index::IdxSize; use arrow::trusted_len::TrustedLen; use arrow::types::NativeType; use polars_core::export::num::{Bounded, Float, NumCast}; diff --git a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs index 8c02b38625a72..abd4eadffc79e 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/rolling_kernels/no_nulls.rs @@ -41,7 +41,7 @@ where } else { // SAFETY: // we are in bounds - Some(unsafe { agg_window.update(start as usize, end as usize) }) + unsafe { agg_window.update(start as usize, end as usize) } } }) }) diff --git a/crates/polars-time/src/date_range.rs b/crates/polars-time/src/date_range.rs index bec0d942f149f..b44afebddb32c 100644 --- a/crates/polars-time/src/date_range.rs +++ b/crates/polars-time/src/date_range.rs @@ -23,11 +23,17 @@ pub fn date_range( ) -> PolarsResult { let (start, end) = match tu { TimeUnit::Nanoseconds => ( - start.timestamp_nanos_opt().unwrap(), - end.timestamp_nanos_opt().unwrap(), + start.and_utc().timestamp_nanos_opt().unwrap(), + end.and_utc().timestamp_nanos_opt().unwrap(), + ), + TimeUnit::Microseconds => ( + start.and_utc().timestamp_micros(), + end.and_utc().timestamp_micros(), + ), + TimeUnit::Milliseconds => ( + start.and_utc().timestamp_millis(), + end.and_utc().timestamp_millis(), ), - TimeUnit::Microseconds => (start.timestamp_micros(), end.timestamp_micros()), - TimeUnit::Milliseconds => (start.timestamp_millis(), end.timestamp_millis()), }; datetime_range_impl(name, start, end, interval, closed, tu, tz) } diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index f4baa0f485732..4510b65ee54f0 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -1,7 +1,6 @@ use arrow::legacy::time_zone::Tz; use arrow::legacy::utils::CustomIterTools; use polars_core::export::rayon::prelude::*; -use polars_core::frame::group_by::GroupsProxy; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_core::utils::ensure_sorted_arg; @@ -151,18 +150,18 @@ impl Wrap<&DataFrame> { TimeUnit::Milliseconds, None, ), - Int32 => { - let time_type = Datetime(TimeUnit::Nanoseconds, None); - let dt = time.cast(&Int64).unwrap().cast(&time_type).unwrap(); + UInt32 | UInt64 | Int32 => { + let time_type_dt = Datetime(TimeUnit::Nanoseconds, None); + let dt = time.cast(&Int64).unwrap().cast(&time_type_dt).unwrap(); let (out, by, gt) = self.impl_group_by_rolling( dt, by, options, TimeUnit::Nanoseconds, None, - &time_type, + &time_type_dt, )?; - let out = out.cast(&Int64).unwrap().cast(&Int32).unwrap(); + let out = out.cast(&Int64).unwrap().cast(time_type).unwrap(); return Ok((out, by, gt)); }, Int64 => { @@ -803,11 +802,13 @@ mod test { .unwrap() .and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let stop = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(3, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let range = datetime_range_impl( "date", @@ -856,11 +857,13 @@ mod test { .unwrap() .and_hms_opt(1, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let stop = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(3, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let range = datetime_range_impl( "_upper_boundary", @@ -879,11 +882,13 @@ mod test { .unwrap() .and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let stop = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(2, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let range = datetime_range_impl( "_lower_boundary", @@ -918,11 +923,13 @@ mod test { .unwrap() .and_hms_opt(12, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let stop = NaiveDate::from_ymd_opt(2021, 3, 7) .unwrap() .and_hms_opt(12, 0, 0) .unwrap() + .and_utc() .timestamp_millis(); let range = datetime_range_impl( "date", diff --git a/crates/polars-time/src/month_start.rs b/crates/polars-time/src/month_start.rs index cc68463625628..79934fe01cf69 100644 --- a/crates/polars-time/src/month_start.rs +++ b/crates/polars-time/src/month_start.rs @@ -49,7 +49,10 @@ pub(crate) fn roll_backward( let ndt = NaiveDateTime::new(date, time); let t = match tz { #[cfg(feature = "timezones")] - Some(tz) => datetime_to_timestamp(try_localize_datetime(ndt, tz, Ambiguous::Raise)?), + Some(tz) => datetime_to_timestamp( + try_localize_datetime(ndt, tz, Ambiguous::Raise)? + .expect("we didn't use Ambiguous::Null"), + ), _ => datetime_to_timestamp(ndt), }; Ok(t) diff --git a/crates/polars-time/src/utils.rs b/crates/polars-time/src/utils.rs index 4d248e700e4cf..f5ca4f4dfeeff 100644 --- a/crates/polars-time/src/utils.rs +++ b/crates/polars-time/src/utils.rs @@ -9,13 +9,20 @@ use chrono::TimeZone; #[cfg(feature = "timezones")] use polars_core::prelude::PolarsResult; +/// Localize datetime according to given time zone. +/// +/// e.g. '2021-01-01 03:00' -> '2021-01-01 03:00CDT' +/// +/// Note: this may only return `Ok(None)` if ambiguous is Ambiguous::Null. +/// Otherwise, it will either return `Ok(Some(NaiveDateTime))` or `PolarsError`. +/// Therefore, calling `try_localize_datetime(..., Ambiguous::Raise)?.unwrap()` +/// is safe, and will never panic. #[cfg(feature = "timezones")] pub(crate) fn try_localize_datetime( ndt: NaiveDateTime, tz: &Tz, ambiguous: Ambiguous, -) -> PolarsResult { - // e.g. '2021-01-01 03:00' -> '2021-01-01 03:00CDT' +) -> PolarsResult> { convert_to_naive_local(&chrono_tz::UTC, tz, ndt, ambiguous) } @@ -24,7 +31,7 @@ pub(crate) fn localize_datetime_opt( ndt: NaiveDateTime, tz: &Tz, ambiguous: Ambiguous, -) -> Option { +) -> Option> { // e.g. '2021-01-01 03:00' -> '2021-01-01 03:00CDT' convert_to_naive_local_opt(&chrono_tz::UTC, tz, ndt, ambiguous) } diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index 9f9640af4d6b9..a843286ea5762 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -440,7 +440,7 @@ impl Duration { ) } - /// Localize result to given time zone., respecting DST fold of original datetime. + /// Localize result to given time zone, respecting DST fold of original datetime. /// For example, 2022-11-06 01:30:00 CST truncated by 1 hour becomes 2022-11-06 01:00:00 CST, /// whereas 2022-11-06 01:30:00 CDT truncated by 1 hour becomes 2022-11-06 01:00:00 CDT. /// @@ -458,18 +458,26 @@ impl Duration { original_dt_utc: NaiveDateTime, result_dt_local: NaiveDateTime, tz: &Tz, - ) -> NaiveDateTime { + ) -> PolarsResult { match localize_datetime_opt(result_dt_local, tz, Ambiguous::Raise) { - Some(dt) => dt, + Some(dt) => Ok(dt.expect("we didn't use Ambiguous::Null")), None => { - if try_localize_datetime(original_dt_local, tz, Ambiguous::Earliest).unwrap() + if try_localize_datetime(original_dt_local, tz, Ambiguous::Earliest)? + .expect("we didn't use Ambiguous::Null") == original_dt_utc { - try_localize_datetime(result_dt_local, tz, Ambiguous::Earliest).unwrap() - } else if try_localize_datetime(original_dt_local, tz, Ambiguous::Latest).unwrap() + Ok( + try_localize_datetime(result_dt_local, tz, Ambiguous::Earliest)? + .expect("we didn't use Ambiguous::Null"), + ) + } else if try_localize_datetime(original_dt_local, tz, Ambiguous::Latest)? + .expect("we didn't use Ambiguous::Null") == original_dt_utc { - try_localize_datetime(result_dt_local, tz, Ambiguous::Latest).unwrap() + Ok( + try_localize_datetime(result_dt_local, tz, Ambiguous::Latest)? + .expect("we didn't use Ambiguous::Null"), + ) } else { unreachable!() } @@ -503,7 +511,7 @@ impl Duration { let result_timestamp = t - remainder; let result_dt_local = _timestamp_to_datetime(result_timestamp); let result_dt_utc = - self.localize_result(original_dt_local, original_dt_utc, result_dt_local, tz); + self.localize_result(original_dt_local, original_dt_utc, result_dt_local, tz)?; Ok(_datetime_to_timestamp(result_dt_utc)) }, _ => { @@ -564,7 +572,7 @@ impl Duration { _original_dt_utc.unwrap(), result_dt_local, tz, - ); + )?; Ok(_datetime_to_timestamp(result_dt_utc)) }, _ => Ok(result_t_local), @@ -647,7 +655,7 @@ impl Duration { Some(tz) if tz != &chrono_tz::UTC => { let result_dt_local = timestamp_to_datetime(t - remainder_days * daily_duration); let result_dt_utc = - self.localize_result(original_dt_local, original_dt_utc, result_dt_local, tz); + self.localize_result(original_dt_local, original_dt_utc, result_dt_local, tz)?; Ok(datetime_to_timestamp(result_dt_utc)) }, _ => Ok(t - remainder_days * daily_duration), @@ -785,9 +793,10 @@ impl Duration { new_t = match tz { #[cfg(feature = "timezones")] // for UTC, use fastpath below (same as naive) - Some(tz) if tz != &chrono_tz::UTC => { - datetime_to_timestamp(try_localize_datetime(dt, tz, Ambiguous::Raise)?) - }, + Some(tz) if tz != &chrono_tz::UTC => datetime_to_timestamp( + try_localize_datetime(dt, tz, Ambiguous::Raise)? + .expect("we didn't use Ambiguous::Null"), + ), _ => datetime_to_timestamp(dt), }; } @@ -801,11 +810,10 @@ impl Duration { new_t = datetime_to_timestamp(unlocalize_datetime(timestamp_to_datetime(t), tz)); new_t += if d.negative { -t_weeks } else { t_weeks }; - new_t = datetime_to_timestamp(try_localize_datetime( - timestamp_to_datetime(new_t), - tz, - Ambiguous::Raise, - )?); + new_t = datetime_to_timestamp( + try_localize_datetime(timestamp_to_datetime(new_t), tz, Ambiguous::Raise)? + .expect("we didn't use Ambiguous::Null"), + ); }, _ => new_t += if d.negative { -t_weeks } else { t_weeks }, }; @@ -820,11 +828,10 @@ impl Duration { new_t = datetime_to_timestamp(unlocalize_datetime(timestamp_to_datetime(t), tz)); new_t += if d.negative { -t_days } else { t_days }; - new_t = datetime_to_timestamp(try_localize_datetime( - timestamp_to_datetime(new_t), - tz, - Ambiguous::Raise, - )?); + new_t = datetime_to_timestamp( + try_localize_datetime(timestamp_to_datetime(new_t), tz, Ambiguous::Raise)? + .expect("we didn't use Ambiguous::Null"), + ); }, _ => new_t += if d.negative { -t_days } else { t_days }, }; diff --git a/crates/polars-time/src/windows/test.rs b/crates/polars-time/src/windows/test.rs index 6fc9c663efe24..7b573c14d49fd 100644 --- a/crates/polars-time/src/windows/test.rs +++ b/crates/polars-time/src/windows/test.rs @@ -2,7 +2,6 @@ use arrow::temporal_conversions::timestamp_ns_to_datetime; use chrono::prelude::*; use polars_core::prelude::*; -use crate::date_range::datetime_range_i64; use crate::prelude::*; #[test] @@ -17,8 +16,8 @@ fn test_date_range() { .and_hms_opt(0, 0, 0) .unwrap(); let dates = datetime_range_i64( - start.timestamp_nanos_opt().unwrap(), - end.timestamp_nanos_opt().unwrap(), + start.and_utc().timestamp_nanos_opt().unwrap(), + end.and_utc().timestamp_nanos_opt().unwrap(), Duration::parse("1mo"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -35,6 +34,7 @@ fn test_date_range() { .map(|d| { d.and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_nanos_opt() .unwrap() }) @@ -53,8 +53,8 @@ fn test_feb_date_range() { .and_hms_opt(0, 0, 0) .unwrap(); let dates = datetime_range_i64( - start.timestamp_nanos_opt().unwrap(), - end.timestamp_nanos_opt().unwrap(), + start.and_utc().timestamp_nanos_opt().unwrap(), + end.and_utc().timestamp_nanos_opt().unwrap(), Duration::parse("1mo"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -69,6 +69,7 @@ fn test_feb_date_range() { .map(|d| { d.and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_nanos_opt() .unwrap() }) @@ -102,6 +103,7 @@ fn test_groups_large_interval() { .map(|d| { d.and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_nanos_opt() .unwrap() }) @@ -156,6 +158,7 @@ fn test_offset() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_nanos_opt() .unwrap(); let w = Window::new( @@ -169,6 +172,7 @@ fn test_offset() { .unwrap() .and_hms_opt(23, 58, 0) .unwrap() + .and_utc() .timestamp_nanos_opt() .unwrap(); assert_eq!(b.start, start); @@ -186,8 +190,8 @@ fn test_boundaries() { .unwrap(); let ts = datetime_range_i64( - start.timestamp_nanos_opt().unwrap(), - stop.timestamp_nanos_opt().unwrap(), + start.and_utc().timestamp_nanos_opt().unwrap(), + stop.and_utc().timestamp_nanos_opt().unwrap(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -206,7 +210,7 @@ fn test_boundaries() { // earliest bound is first datapoint: 2021-12-16 00:00:00 let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); - assert_eq!(b.start, start.timestamp_nanos_opt().unwrap()); + assert_eq!(b.start, start.and_utc().timestamp_nanos_opt().unwrap()); // test closed: "both" (includes both ends of the interval) let (groups, lower, higher) = group_by_windows( @@ -243,9 +247,9 @@ fn test_boundaries() { assert_eq!( g, &[ - t0.timestamp_nanos_opt().unwrap(), - t1.timestamp_nanos_opt().unwrap(), - t2.timestamp_nanos_opt().unwrap() + t0.and_utc().timestamp_nanos_opt().unwrap(), + t1.and_utc().timestamp_nanos_opt().unwrap(), + t2.and_utc().timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -259,8 +263,8 @@ fn test_boundaries() { assert_eq!( &[lower[0], higher[0]], &[ - b_start.timestamp_nanos_opt().unwrap(), - b_end.timestamp_nanos_opt().unwrap() + b_start.and_utc().timestamp_nanos_opt().unwrap(), + b_end.and_utc().timestamp_nanos_opt().unwrap() ] ); @@ -287,9 +291,9 @@ fn test_boundaries() { assert_eq!( g, &[ - t0.timestamp_nanos_opt().unwrap(), - t1.timestamp_nanos_opt().unwrap(), - t2.timestamp_nanos_opt().unwrap() + t0.and_utc().timestamp_nanos_opt().unwrap(), + t1.and_utc().timestamp_nanos_opt().unwrap(), + t2.and_utc().timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -303,8 +307,8 @@ fn test_boundaries() { assert_eq!( &[lower[1], higher[1]], &[ - b_start.timestamp_nanos_opt().unwrap(), - b_end.timestamp_nanos_opt().unwrap() + b_start.and_utc().timestamp_nanos_opt().unwrap(), + b_end.and_utc().timestamp_nanos_opt().unwrap() ] ); @@ -368,8 +372,8 @@ fn test_boundaries_2() { .unwrap(); let ts = datetime_range_i64( - start.timestamp_nanos_opt().unwrap(), - stop.timestamp_nanos_opt().unwrap(), + start.and_utc().timestamp_nanos_opt().unwrap(), + stop.and_utc().timestamp_nanos_opt().unwrap(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -391,7 +395,7 @@ fn test_boundaries_2() { assert_eq!( b.start, - start.timestamp_nanos_opt().unwrap() + offset.duration_ns() + start.and_utc().timestamp_nanos_opt().unwrap() + offset.duration_ns() ); let (groups, lower, higher) = group_by_windows( @@ -425,8 +429,8 @@ fn test_boundaries_2() { assert_eq!( g, &[ - t0.timestamp_nanos_opt().unwrap(), - t1.timestamp_nanos_opt().unwrap() + t0.and_utc().timestamp_nanos_opt().unwrap(), + t1.and_utc().timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -440,8 +444,8 @@ fn test_boundaries_2() { assert_eq!( &[lower[0], higher[0]], &[ - b_start.timestamp_nanos_opt().unwrap(), - b_end.timestamp_nanos_opt().unwrap() + b_start.and_utc().timestamp_nanos_opt().unwrap(), + b_end.and_utc().timestamp_nanos_opt().unwrap() ] ); @@ -464,8 +468,8 @@ fn test_boundaries_2() { assert_eq!( g, &[ - t0.timestamp_nanos_opt().unwrap(), - t1.timestamp_nanos_opt().unwrap() + t0.and_utc().timestamp_nanos_opt().unwrap(), + t1.and_utc().timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -479,8 +483,8 @@ fn test_boundaries_2() { assert_eq!( &[lower[1], higher[1]], &[ - b_start.timestamp_nanos_opt().unwrap(), - b_end.timestamp_nanos_opt().unwrap() + b_start.and_utc().timestamp_nanos_opt().unwrap(), + b_end.and_utc().timestamp_nanos_opt().unwrap() ] ); } @@ -497,8 +501,8 @@ fn test_boundaries_ms() { .unwrap(); let ts = datetime_range_i64( - start.timestamp_millis(), - stop.timestamp_millis(), + start.and_utc().timestamp_millis(), + stop.and_utc().timestamp_millis(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Milliseconds, @@ -517,7 +521,7 @@ fn test_boundaries_ms() { // earliest bound is first datapoint: 2021-12-16 00:00:00 let b = w.get_earliest_bounds_ms(ts[0], None).unwrap(); - assert_eq!(b.start, start.timestamp_millis()); + assert_eq!(b.start, start.and_utc().timestamp_millis()); // test closed: "both" (includes both ends of the interval) let (groups, lower, higher) = group_by_windows( @@ -554,9 +558,9 @@ fn test_boundaries_ms() { assert_eq!( g, &[ - t0.timestamp_millis(), - t1.timestamp_millis(), - t2.timestamp_millis() + t0.and_utc().timestamp_millis(), + t1.and_utc().timestamp_millis(), + t2.and_utc().timestamp_millis() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -569,7 +573,10 @@ fn test_boundaries_ms() { .unwrap(); assert_eq!( &[lower[0], higher[0]], - &[b_start.timestamp_millis(), b_end.timestamp_millis()] + &[ + b_start.and_utc().timestamp_millis(), + b_end.and_utc().timestamp_millis() + ] ); // 2nd group @@ -595,9 +602,9 @@ fn test_boundaries_ms() { assert_eq!( g, &[ - t0.timestamp_millis(), - t1.timestamp_millis(), - t2.timestamp_millis() + t0.and_utc().timestamp_millis(), + t1.and_utc().timestamp_millis(), + t2.and_utc().timestamp_millis() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -610,7 +617,10 @@ fn test_boundaries_ms() { .unwrap(); assert_eq!( &[lower[1], higher[1]], - &[b_start.timestamp_millis(), b_end.timestamp_millis()] + &[ + b_start.and_utc().timestamp_millis(), + b_end.and_utc().timestamp_millis() + ] ); assert_eq!(groups[2], [4, 3]); @@ -673,8 +683,8 @@ fn test_rolling_lookback() { .and_hms_opt(4, 0, 0) .unwrap(); let dates = datetime_range_i64( - start.timestamp_millis(), - end.timestamp_millis(), + start.and_utc().timestamp_millis(), + end.and_utc().timestamp_millis(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Milliseconds, @@ -788,11 +798,13 @@ fn test_end_membership() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_millis(), NaiveDate::from_ymd_opt(2021, 5, 1) .unwrap() .and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_millis(), ]; let window = Window::new( @@ -879,6 +891,7 @@ fn test_group_by_windows_offsets_3776() { .map(|d| { d.and_hms_opt(0, 0, 0) .unwrap() + .and_utc() .timestamp_nanos_opt() .unwrap() }) diff --git a/crates/polars-time/src/windows/window.rs b/crates/polars-time/src/windows/window.rs index 3666013a18e3b..8adb7520ecfec 100644 --- a/crates/polars-time/src/windows/window.rs +++ b/crates/polars-time/src/windows/window.rs @@ -5,7 +5,6 @@ use chrono::NaiveDateTime; use chrono::TimeZone; use now::DateTimeNow; use polars_core::prelude::*; -use polars_core::utils::arrow::temporal_conversions::timeunit_scale; use crate::prelude::*; diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index 2afe25a87a7f6..e2810b67daaf0 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -17,10 +17,14 @@ hashbrown = { workspace = true } indexmap = { workspace = true } num-traits = { workspace = true } once_cell = { workspace = true } +raw-cpuid = { workspace = true } rayon = { workspace = true } smartstring = { workspace = true } sysinfo = { version = "0.30", default-features = false, optional = true } +[dev-dependencies] +rand = { workspace = true } + [build-dependencies] version_check = { workspace = true } diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs index d44689c0f836b..df367b733f1f2 100644 --- a/crates/polars-utils/src/arena.rs +++ b/crates/polars-utils/src/arena.rs @@ -18,6 +18,7 @@ fn index_of(slice: &[T], item: &T) -> Option { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +#[repr(transparent)] pub struct Node(pub usize); impl Default for Node { diff --git a/crates/polars-utils/src/clmul.rs b/crates/polars-utils/src/clmul.rs new file mode 100644 index 0000000000000..eb9e03f536c9e --- /dev/null +++ b/crates/polars-utils/src/clmul.rs @@ -0,0 +1,151 @@ +#[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] +fn intel_clmul64(x: u64, y: u64) -> u64 { + use core::arch::x86_64::*; + unsafe { + // SAFETY: we have the target feature. + _mm_cvtsi128_si64(_mm_clmulepi64_si128( + _mm_cvtsi64_si128(x as i64), + _mm_cvtsi64_si128(y as i64), + 0, + )) as u64 + } +} + +#[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" +))] +fn arm_clmul64(x: u64, y: u64) -> u64 { + unsafe { + // SAFETY: we have the target feature. + use core::arch::aarch64::*; + vmull_p64(x, y) as u64 + } +} + +#[inline] +pub fn portable_clmul64(x: u64, mut y: u64) -> u64 { + let mut out = 0; + while y > 0 { + let lsb = y & y.wrapping_neg(); + out ^= x.wrapping_mul(lsb); + y ^= lsb; + } + out +} + +// Computes the carryless multiplication of x and y. +#[inline] +pub fn clmul64(x: u64, y: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] + return intel_clmul64(x, y); + + #[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" + ))] + return arm_clmul64(x, y); + + #[allow(unreachable_code)] + portable_clmul64(x, y) +} + +#[inline] +pub fn portable_prefix_xorsum(mut x: u64) -> u64 { + x <<= 1; + for i in 0..6 { + x ^= x << (1 << i); + } + x +} + +// Computes for each bit i the XOR of all less significant bits. +#[inline] +pub fn prefix_xorsum(x: u64) -> u64 { + #[cfg(all(target_arch = "x86_64", target_feature = "pclmulqdq"))] + return intel_clmul64(x, u64::MAX ^ 1); + + #[cfg(all( + target_arch = "aarch64", + target_feature = "neon", + target_feature = "aes" + ))] + return arm_clmul64(x, u64::MAX ^ 1); + + #[allow(unreachable_code)] + portable_prefix_xorsum(x) +} + +#[cfg(test)] +mod test { + use rand::prelude::*; + + use super::*; + + #[test] + fn test_clmul() { + // Verify platform-specific clmul to portable. + let mut rng = StdRng::seed_from_u64(0xdeadbeef); + for _ in 0..100 { + let x = rng.gen(); + let y = rng.gen(); + assert_eq!(portable_clmul64(x, y), clmul64(x, y)); + } + + // Verify portable clmul for known test vectors. + assert_eq!( + portable_clmul64(0x8b44729195dde0ef, 0xb976c5ae2726fab0), + 0x4ae14eae84899290 + ); + assert_eq!( + portable_clmul64(0x399b6ed00c44b301, 0x693341db5acb2ff0), + 0x48dfa88344823ff0 + ); + assert_eq!( + portable_clmul64(0xdf4c9f6e60deb640, 0x6d4bcdb217ac4880), + 0x7300ffe474792000 + ); + assert_eq!( + portable_clmul64(0xa7adf3c53a200a51, 0x818cb40fe11b431e), + 0x6a280181d521797e + ); + assert_eq!( + portable_clmul64(0x5e78e12b744f228c, 0x4225ff19e9273266), + 0xa48b73cafb9665a8 + ); + } + + #[test] + fn test_prefix_xorsum() { + // Verify platform-specific prefix_xorsum to portable. + let mut rng = StdRng::seed_from_u64(0xdeadbeef); + for _ in 0..100 { + let x = rng.gen(); + assert_eq!(portable_prefix_xorsum(x), prefix_xorsum(x)); + } + + // Verify portable prefix_xorsum for known test vectors. + assert_eq!( + portable_prefix_xorsum(0x8b44729195dde0ef), + 0x0d87a31ee696bf4a + ); + assert_eq!( + portable_prefix_xorsum(0xb976c5ae2726fab0), + 0x2e5b79343a3b5320 + ); + assert_eq!( + portable_prefix_xorsum(0x399b6ed00c44b301), + 0xd1124b600878ddfe + ); + assert_eq!( + portable_prefix_xorsum(0x693341db5acb2ff0), + 0x4e227e926c8dcaa0 + ); + assert_eq!( + portable_prefix_xorsum(0xdf4c9f6e60deb640), + 0x6a7715b44094db80 + ); + } +} diff --git a/crates/polars-utils/src/cpuid.rs b/crates/polars-utils/src/cpuid.rs new file mode 100644 index 0000000000000..71eda848a878a --- /dev/null +++ b/crates/polars-utils/src/cpuid.rs @@ -0,0 +1,39 @@ +// So much conditional stuff going on here... +#![allow(dead_code, unreachable_code, unused)] + +use std::sync::OnceLock; + +#[cfg(target_arch = "x86_64")] +use raw_cpuid::CpuId; + +#[cfg(target_feature = "bmi2")] +#[inline(never)] +#[cold] +fn detect_fast_bmi2() -> bool { + let cpu_id = CpuId::new(); + let vendor = cpu_id.get_vendor_info().expect("could not read cpu vendor"); + if vendor.as_str() == "AuthenticAMD" || vendor.as_str() == "HygonGenuine" { + let features = cpu_id + .get_feature_info() + .expect("could not read cpu feature info"); + let family_id = features.family_id(); + + // Hardcoded blacklist of known-bad AMD families. + // We'll assume any future releases that support BMI2 have a + // proper implementation. + !(family_id >= 0x15 && family_id <= 0x18) + } else { + true + } +} + +#[inline] +pub fn has_fast_bmi2() -> bool { + #[cfg(target_feature = "bmi2")] + { + static CACHE: OnceLock = OnceLock::new(); + return *CACHE.get_or_init(detect_fast_bmi2); + } + + false +} diff --git a/crates/polars-utils/src/signed_divmod.rs b/crates/polars-utils/src/floor_divmod.rs similarity index 66% rename from crates/polars-utils/src/signed_divmod.rs rename to crates/polars-utils/src/floor_divmod.rs index 188884d41f1da..14c02fc8d257c 100644 --- a/crates/polars-utils/src/signed_divmod.rs +++ b/crates/polars-utils/src/floor_divmod.rs @@ -1,16 +1,40 @@ -pub trait SignedDivMod: Sized { +pub trait FloorDivMod: Sized { // Returns the flooring division and associated modulo of lhs / rhs. // This is the same division / modulo combination as Python. // // Returns (0, 0) if other == 0. - fn wrapping_div_mod(self, other: Self) -> (Self, Self); + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self); +} + +macro_rules! impl_float_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + let div = (self / other).floor(); + let mod_ = self - other * div; + (div, mod_) + } + } + }; +} + +macro_rules! impl_unsigned_div_mod { + ($T:ty) => { + impl FloorDivMod for $T { + #[inline] + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { + (self / other, self % other) + } + } + }; } macro_rules! impl_signed_div_mod { ($T:ty) => { - impl SignedDivMod for $T { + impl FloorDivMod for $T { #[inline] - fn wrapping_div_mod(self, other: Self) -> (Self, Self) { + fn wrapping_floor_div_mod(self, other: Self) -> (Self, Self) { if other == 0 { return (0, 0); } @@ -37,12 +61,20 @@ macro_rules! impl_signed_div_mod { }; } +impl_unsigned_div_mod!(u8); +impl_unsigned_div_mod!(u16); +impl_unsigned_div_mod!(u32); +impl_unsigned_div_mod!(u64); +impl_unsigned_div_mod!(u128); +impl_unsigned_div_mod!(usize); impl_signed_div_mod!(i8); impl_signed_div_mod!(i16); impl_signed_div_mod!(i32); impl_signed_div_mod!(i64); impl_signed_div_mod!(i128); impl_signed_div_mod!(isize); +impl_float_div_mod!(f32); +impl_float_div_mod!(f64); #[cfg(test)] mod test { @@ -63,7 +95,7 @@ mod test { (0, 0) }; - assert_eq!(lhs.wrapping_div_mod(rhs), ans); + assert_eq!(lhs.wrapping_floor_div_mod(rhs), ans); } } } diff --git a/crates/polars-utils/src/hashing.rs b/crates/polars-utils/src/hashing.rs index df44c5be7f855..12e59bf52f26a 100644 --- a/crates/polars-utils/src/hashing.rs +++ b/crates/polars-utils/src/hashing.rs @@ -46,7 +46,7 @@ impl<'a> PartialEq for BytesHash<'a> { } } -#[inline] +#[inline(always)] pub fn hash_to_partition(h: u64, n_partitions: usize) -> usize { // Assuming h is a 64-bit random number, we note that // h / 2^64 is almost a uniform random number in [0, 1), and thus @@ -86,6 +86,14 @@ impl_hash_partition_as_u64!(i16); impl_hash_partition_as_u64!(i32); impl_hash_partition_as_u64!(i64); +impl DirtyHash for i128 { + fn dirty_hash(&self) -> u64 { + (*self as u64) + .wrapping_mul(RANDOM_ODD) + .wrapping_add((*self >> 64) as u64) + } +} + impl<'a> DirtyHash for BytesHash<'a> { fn dirty_hash(&self) -> u64 { self.hash diff --git a/crates/polars-utils/src/index.rs b/crates/polars-utils/src/index.rs index 1815860383f49..1ca29d3947273 100644 --- a/crates/polars-utils/src/index.rs +++ b/crates/polars-utils/src/index.rs @@ -1,8 +1,73 @@ -use polars_error::{polars_bail, polars_ensure, PolarsResult}; +use std::fmt::{Debug, Formatter}; + +use polars_error::{polars_ensure, PolarsResult}; use crate::nulls::IsNull; use crate::slice::GetSaferUnchecked; -use crate::IdxSize; + +#[cfg(not(feature = "bigidx"))] +pub type IdxSize = u32; +#[cfg(feature = "bigidx")] +pub type IdxSize = u64; + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct NullableIdxSize { + pub inner: IdxSize, +} + +impl PartialEq for NullableIdxSize { + fn eq(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl Eq for NullableIdxSize {} + +unsafe impl bytemuck::Zeroable for NullableIdxSize {} +unsafe impl bytemuck::AnyBitPattern for NullableIdxSize {} +unsafe impl bytemuck::NoUninit for NullableIdxSize {} + +impl Debug for NullableIdxSize { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.inner) + } +} + +impl NullableIdxSize { + #[inline(always)] + pub fn is_null_idx(&self) -> bool { + self.inner == IdxSize::MAX + } + + #[inline(always)] + pub const fn null() -> Self { + Self { + inner: IdxSize::MAX, + } + } + + #[inline(always)] + pub fn idx(&self) -> IdxSize { + self.inner + } + + #[inline(always)] + pub fn to_opt(&self) -> Option { + if self.is_null_idx() { + None + } else { + Some(self.idx()) + } + } +} + +impl From for NullableIdxSize { + #[inline(always)] + fn from(value: IdxSize) -> Self { + Self { inner: value } + } +} pub trait Bounded { fn len(&self) -> usize; @@ -116,15 +181,38 @@ impl_to_idx!(i64, i64); // Allows for 2^24 (~16M) chunks // Leaves 2^40 (~1T) rows per chunk -const CHUNK_BITS: u64 = 24; +const DEFAULT_CHUNK_BITS: u64 = 24; -#[derive(Clone, Copy, Debug)] -#[repr(C)] -pub struct ChunkId { +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct ChunkId { swizzled: u64, } -impl ChunkId { +pub type NullableChunkId = ChunkId; + +impl Debug for ChunkId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.is_null() { + write!(f, "NULL") + } else { + let (chunk, row) = self.extract(); + write!(f, "({chunk}, {row})") + } + } +} + +impl ChunkId { + #[inline(always)] + pub const fn null() -> Self { + Self { swizzled: u64::MAX } + } + + #[inline(always)] + pub fn is_null(&self) -> bool { + self.swizzled == u64::MAX + } + #[inline(always)] #[allow(clippy::unnecessary_cast)] pub fn store(chunk: IdxSize, row: IdxSize) -> Self { @@ -139,10 +227,15 @@ impl ChunkId { pub fn extract(self) -> (IdxSize, IdxSize) { let row = (self.swizzled >> CHUNK_BITS) as IdxSize; - const MASK: IdxSize = IdxSize::MAX << CHUNK_BITS; - let chunk = (self.swizzled as IdxSize) & !MASK; + let mask: IdxSize = IdxSize::MAX << CHUNK_BITS; + let chunk = (self.swizzled as IdxSize) & !mask; (chunk, row) } + + #[inline(always)] + pub fn inner_mut(&mut self) -> &mut u64 { + &mut self.swizzled + } } #[cfg(test)] @@ -154,7 +247,7 @@ mod test { let chunk = 213908; let row = 813457; - let ci = ChunkId::store(chunk, row); + let ci: ChunkId = ChunkId::store(chunk, row); let (c, r) = ci.extract(); assert_eq!(c, chunk); diff --git a/crates/polars-utils/src/iter/fallible.rs b/crates/polars-utils/src/iter/fallible.rs new file mode 100644 index 0000000000000..7ba544b13e18b --- /dev/null +++ b/crates/polars-utils/src/iter/fallible.rs @@ -0,0 +1,17 @@ +use std::error::Error; + +pub trait FallibleIterator: Iterator { + fn get_result(&mut self) -> Result<(), E>; +} + +pub trait FromFallibleIterator: Sized { + fn from_fallible_iter>(iter: F) -> Result; +} + +impl, E: Error> FromFallibleIterator for T { + fn from_fallible_iter>(mut iter: F) -> Result { + let out = T::from_iter(&mut iter); + iter.get_result()?; + Ok(out) + } +} diff --git a/crates/polars-utils/src/iter/mod.rs b/crates/polars-utils/src/iter/mod.rs index 408884f623d31..b3158416abb2a 100644 --- a/crates/polars-utils/src/iter/mod.rs +++ b/crates/polars-utils/src/iter/mod.rs @@ -1,5 +1,8 @@ mod enumerate_idx; +mod fallible; + pub use enumerate_idx::EnumerateIdxTrait; +pub use fallible::*; pub trait IntoIteratorCopied: IntoIterator { /// The type of the elements being iterated over. diff --git a/crates/polars-utils/src/lib.rs b/crates/polars-utils/src/lib.rs index 5ad52e2e4e721..575571b62985a 100644 --- a/crates/polars-utils/src/lib.rs +++ b/crates/polars-utils/src/lib.rs @@ -4,14 +4,16 @@ pub mod arena; pub mod atomic; pub mod cache; pub mod cell; +pub mod clmul; pub mod contention_pool; +pub mod cpuid; mod error; +pub mod floor_divmod; pub mod functions; pub mod hashing; pub mod idx_vec; pub mod mem; pub mod min_max; -pub mod signed_divmod; pub mod slice; pub mod sort; pub mod sync; @@ -22,11 +24,6 @@ pub mod unwrap; pub use functions::*; -#[cfg(not(feature = "bigidx"))] -pub type IdxSize = u32; -#[cfg(feature = "bigidx")] -pub type IdxSize = u64; - pub mod aliases; pub mod fmt; pub mod iter; @@ -40,5 +37,7 @@ pub mod index; pub mod io; pub mod nulls; pub mod ord; +pub mod partitioned; +pub use index::{IdxSize, NullableIdxSize}; pub use io::open_file; diff --git a/crates/polars-utils/src/partitioned.rs b/crates/polars-utils/src/partitioned.rs new file mode 100644 index 0000000000000..c23af9e95d761 --- /dev/null +++ b/crates/polars-utils/src/partitioned.rs @@ -0,0 +1,49 @@ +use hashbrown::hash_map::{HashMap, RawEntryBuilder, RawEntryBuilderMut}; + +use crate::hashing::hash_to_partition; +use crate::slice::GetSaferUnchecked; + +pub struct PartitionedHashMap { + inner: Vec>, +} + +impl PartitionedHashMap { + pub fn new(inner: Vec>) -> Self { + Self { inner } + } + + #[inline(always)] + pub fn raw_entry_mut(&mut self, h: u64) -> RawEntryBuilderMut<'_, K, V, S> { + self.raw_entry_and_partition_mut(h).0 + } + + #[inline(always)] + pub fn raw_entry(&self, h: u64) -> RawEntryBuilder<'_, K, V, S> { + self.raw_entry_and_partition(h).0 + } + + #[inline] + pub fn raw_entry_and_partition(&self, h: u64) -> (RawEntryBuilder<'_, K, V, S>, usize) { + let partition = hash_to_partition(h, self.inner.len()); + let current_table = unsafe { self.inner.get_unchecked_release(partition) }; + (current_table.raw_entry(), partition) + } + + #[inline] + pub fn raw_entry_and_partition_mut( + &mut self, + h: u64, + ) -> (RawEntryBuilderMut<'_, K, V, S>, usize) { + let partition = hash_to_partition(h, self.inner.len()); + let current_table = unsafe { self.inner.get_unchecked_release_mut(partition) }; + (current_table.raw_entry_mut(), partition) + } + + pub fn inner(&self) -> &[HashMap] { + self.inner.as_ref() + } + + pub fn inner_mut(&mut self) -> &mut Vec> { + &mut self.inner + } +} diff --git a/crates/polars-utils/src/slice.rs b/crates/polars-utils/src/slice.rs index 161fb4f25260b..30bdfa0f13027 100644 --- a/crates/polars-utils/src/slice.rs +++ b/crates/polars-utils/src/slice.rs @@ -121,3 +121,27 @@ impl Slice2Uninit for [T] { unsafe { std::slice::from_raw_parts(self.as_ptr() as *const MaybeUninit, self.len()) } } } + +// Loads a u64 from the given byteslice, as if it were padded with zeros. +#[inline] +pub fn load_padded_le_u64(bytes: &[u8]) -> u64 { + let len = bytes.len(); + if len >= 8 { + return u64::from_le_bytes(bytes[0..8].try_into().unwrap()); + } + + if len >= 4 { + let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap()); + let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap()); + return (lo as u64) | ((hi as u64) << (8 * (len - 4))); + } + + if len == 0 { + return 0; + } + + let lo = bytes[0] as u64; + let mid = (bytes[len / 2] as u64) << (8 * (len / 2)); + let hi = (bytes[len - 1] as u64) << (8 * (len - 1)); + lo | mid | hi +} diff --git a/crates/polars-utils/src/total_ord.rs b/crates/polars-utils/src/total_ord.rs index 8dac484d5d962..5b9af065779fd 100644 --- a/crates/polars-utils/src/total_ord.rs +++ b/crates/polars-utils/src/total_ord.rs @@ -3,6 +3,9 @@ use std::hash::{Hash, Hasher}; use bytemuck::TransparentWrapper; +use crate::hashing::{BytesHash, DirtyHash}; +use crate::nulls::IsNull; + /// Converts an f32 into a canonical form, where -0 == 0 and all NaNs map to /// the same value. pub fn canonical_f32(x: f32) -> f32 { @@ -32,7 +35,7 @@ pub fn canonical_f64(x: f64) -> f64 { pub trait TotalEq { fn tot_eq(&self, other: &Self) -> bool; - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { !(self.tot_eq(other)) } @@ -43,22 +46,22 @@ pub trait TotalEq { pub trait TotalOrd: TotalEq { fn tot_cmp(&self, other: &Self) -> Ordering; - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { self.tot_cmp(other) == Ordering::Less } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { self.tot_cmp(other) == Ordering::Greater } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { self.tot_cmp(other) != Ordering::Greater } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { self.tot_cmp(other) != Ordering::Less } @@ -87,46 +90,46 @@ pub struct TotalOrdWrap(pub T); unsafe impl TransparentWrapper for TotalOrdWrap {} impl PartialOrd for TotalOrdWrap { - #[inline(always)] + #[inline] fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } - #[inline(always)] + #[inline] fn lt(&self, other: &Self) -> bool { self.0.tot_lt(&other.0) } - #[inline(always)] + #[inline] fn le(&self, other: &Self) -> bool { self.0.tot_le(&other.0) } - #[inline(always)] + #[inline] fn gt(&self, other: &Self) -> bool { self.0.tot_gt(&other.0) } - #[inline(always)] + #[inline] fn ge(&self, other: &Self) -> bool { self.0.tot_ge(&other.0) } } impl Ord for TotalOrdWrap { - #[inline(always)] + #[inline] fn cmp(&self, other: &Self) -> Ordering { self.0.tot_cmp(&other.0) } } impl PartialEq for TotalOrdWrap { - #[inline(always)] + #[inline] fn eq(&self, other: &Self) -> bool { self.0.tot_eq(&other.0) } - #[inline(always)] + #[inline] #[allow(clippy::partialeq_ne_impl)] fn ne(&self, other: &Self) -> bool { self.0.tot_ne(&other.0) @@ -136,12 +139,14 @@ impl PartialEq for TotalOrdWrap { impl Eq for TotalOrdWrap {} impl Hash for TotalOrdWrap { + #[inline] fn hash(&self, state: &mut H) { self.0.tot_hash(state); } } impl Clone for TotalOrdWrap { + #[inline] fn clone(&self) -> Self { Self(self.0.clone()) } @@ -149,48 +154,85 @@ impl Clone for TotalOrdWrap { impl Copy for TotalOrdWrap {} +impl IsNull for TotalOrdWrap { + const HAS_NULLS: bool = T::HAS_NULLS; + type Inner = T::Inner; + + #[inline] + fn is_null(&self) -> bool { + self.0.is_null() + } + + #[inline] + fn unwrap_inner(self) -> Self::Inner { + self.0.unwrap_inner() + } +} + +impl DirtyHash for f32 { + #[inline] + fn dirty_hash(&self) -> u64 { + canonical_f32(*self).to_bits().dirty_hash() + } +} + +impl DirtyHash for f64 { + #[inline] + fn dirty_hash(&self) -> u64 { + canonical_f64(*self).to_bits().dirty_hash() + } +} + +impl DirtyHash for TotalOrdWrap { + #[inline] + fn dirty_hash(&self) -> u64 { + self.0.dirty_hash() + } +} + macro_rules! impl_trivial_total { ($T: ty) => { impl TotalEq for $T { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { self == other } - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { self != other } } impl TotalOrd for $T { - #[inline(always)] + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { self.cmp(other) } - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { self < other } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { self > other } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { self <= other } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { self >= other } } impl TotalHash for $T { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -224,7 +266,7 @@ impl_trivial_total!(String); macro_rules! impl_float_eq_ord { ($T:ty) => { impl TotalEq for $T { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { if self.is_nan() { other.is_nan() @@ -235,7 +277,7 @@ macro_rules! impl_float_eq_ord { } impl TotalOrd for $T { - #[inline(always)] + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { if self.tot_lt(other) { Ordering::Less @@ -246,22 +288,22 @@ macro_rules! impl_float_eq_ord { } } - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { !self.tot_ge(other) } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { other.tot_lt(self) } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { other.tot_ge(self) } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { // We consider all NaNs equal, and NaN is the largest possible // value. Thus if self is NaN we always return true. Otherwise @@ -278,6 +320,7 @@ impl_float_eq_ord!(f32); impl_float_eq_ord!(f64); impl TotalHash for f32 { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -287,6 +330,7 @@ impl TotalHash for f32 { } impl TotalHash for f64 { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -297,7 +341,7 @@ impl TotalHash for f64 { // Blanket implementations. impl TotalEq for Option { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { match (self, other) { (None, None) => true, @@ -306,7 +350,7 @@ impl TotalEq for Option { } } - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { match (self, other) { (None, None) => false, @@ -317,7 +361,7 @@ impl TotalEq for Option { } impl TotalOrd for Option { - #[inline(always)] + #[inline] fn tot_cmp(&self, other: &Self) -> Ordering { match (self, other) { (None, None) => Ordering::Equal, @@ -327,7 +371,7 @@ impl TotalOrd for Option { } } - #[inline(always)] + #[inline] fn tot_lt(&self, other: &Self) -> bool { match (self, other) { (None, Some(_)) => true, @@ -336,12 +380,12 @@ impl TotalOrd for Option { } } - #[inline(always)] + #[inline] fn tot_gt(&self, other: &Self) -> bool { other.tot_lt(self) } - #[inline(always)] + #[inline] fn tot_le(&self, other: &Self) -> bool { match (self, other) { (Some(_), None) => false, @@ -350,7 +394,7 @@ impl TotalOrd for Option { } } - #[inline(always)] + #[inline] fn tot_ge(&self, other: &Self) -> bool { other.tot_le(self) } @@ -369,18 +413,19 @@ impl TotalHash for Option { } impl TotalEq for &T { - #[inline(always)] + #[inline] fn tot_eq(&self, other: &Self) -> bool { (*self).tot_eq(*other) } - #[inline(always)] + #[inline] fn tot_ne(&self, other: &Self) -> bool { (*self).tot_ne(*other) } } impl TotalHash for &T { + #[inline] fn tot_hash(&self, state: &mut H) where H: Hasher, @@ -402,3 +447,159 @@ impl TotalOrd for (T, U) { .then_with(|| self.1.tot_cmp(&other.1)) } } + +impl<'a> TotalHash for BytesHash<'a> { + #[inline] + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state) + } +} + +impl<'a> TotalEq for BytesHash<'a> { + #[inline] + fn tot_eq(&self, other: &Self) -> bool { + self == other + } +} + +/// This elides creating a [`TotalOrdWrap`] for types that don't need it. +pub trait ToTotalOrd { + type TotalOrdItem; + type SourceItem; + + fn to_total_ord(&self) -> Self::TotalOrdItem; + + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem; +} + +macro_rules! impl_to_total_ord_identity { + ($T: ty) => { + impl ToTotalOrd for $T { + type TotalOrdItem = $T; + type SourceItem = $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + self.clone() + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } + } + }; +} + +impl_to_total_ord_identity!(bool); +impl_to_total_ord_identity!(u8); +impl_to_total_ord_identity!(u16); +impl_to_total_ord_identity!(u32); +impl_to_total_ord_identity!(u64); +impl_to_total_ord_identity!(u128); +impl_to_total_ord_identity!(usize); +impl_to_total_ord_identity!(i8); +impl_to_total_ord_identity!(i16); +impl_to_total_ord_identity!(i32); +impl_to_total_ord_identity!(i64); +impl_to_total_ord_identity!(i128); +impl_to_total_ord_identity!(isize); +impl_to_total_ord_identity!(char); +impl_to_total_ord_identity!(String); + +macro_rules! impl_to_total_ord_lifetimed_ref_identity { + ($T: ty) => { + impl<'a> ToTotalOrd for &'a $T { + type TotalOrdItem = &'a $T; + type SourceItem = &'a $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + *self + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } + } + }; +} + +impl_to_total_ord_lifetimed_ref_identity!(str); +impl_to_total_ord_lifetimed_ref_identity!([u8]); + +macro_rules! impl_to_total_ord_wrapped { + ($T: ty) => { + impl ToTotalOrd for $T { + type TotalOrdItem = TotalOrdWrap<$T>; + type SourceItem = $T; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(self.clone()) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } + } + }; +} + +impl_to_total_ord_wrapped!(f32); +impl_to_total_ord_wrapped!(f64); + +/// This is safe without needing to map the option value to TotalOrdWrap, since +/// for example: +/// `TotalOrdWrap>` implements `Eq + Hash`, iff: +/// `Option` implements `TotalEq + TotalHash`, iff: +/// `T` implements `TotalEq + TotalHash` +impl ToTotalOrd for Option { + type TotalOrdItem = TotalOrdWrap>; + type SourceItem = Option; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + TotalOrdWrap(*self) + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item.0 + } +} + +impl ToTotalOrd for &T { + type TotalOrdItem = T::TotalOrdItem; + type SourceItem = T::SourceItem; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + (*self).to_total_ord() + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + T::peel_total_ord(ord_item) + } +} + +impl<'a> ToTotalOrd for BytesHash<'a> { + type TotalOrdItem = BytesHash<'a>; + type SourceItem = BytesHash<'a>; + + #[inline] + fn to_total_ord(&self) -> Self::TotalOrdItem { + *self + } + + #[inline] + fn peel_total_ord(ord_item: Self::TotalOrdItem) -> Self::SourceItem { + ord_item + } +} diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index c763b69a4bad4..444d7b7cc947c 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -11,17 +11,31 @@ repository = { workspace = true } description = "DataFrame library based on Apache Arrow" [dependencies] +arrow = { workspace = true } polars-core = { workspace = true, features = ["algorithm_group_by"] } +polars-error = { workspace = true } polars-io = { workspace = true, optional = true } polars-lazy = { workspace = true, optional = true } polars-ops = { workspace = true, optional = true } +polars-parquet = { workspace = true } polars-plan = { workspace = true, optional = true } polars-sql = { workspace = true, optional = true } polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } [dev-dependencies] ahash = { workspace = true } +apache-avro = { version = "0.16", features = ["snappy"] } +avro-schema = { workspace = true, features = ["async"] } +either = { workspace = true } +ethnum = "1" +futures = { workspace = true } +# used to run formal property testing +proptest = { version = "1", default_features = false, features = ["std"] } rand = { workspace = true } +# used to test async readers +tokio = { workspace = true, features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { workspace = true, features = ["compat"] } [build-dependencies] version_check = { workspace = true } @@ -112,7 +126,7 @@ abs = ["polars-ops/abs", "polars-lazy?/abs"] approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique"] arg_where = ["polars-lazy?/arg_where"] array_any_all = ["polars-lazy?/array_any_all", "dtype-array"] -asof_join = ["polars-core/asof_join", "polars-lazy?/asof_join", "polars-ops/asof_join"] +asof_join = ["polars-lazy?/asof_join", "polars-ops/asof_join"] bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] binary_encoding = ["polars-ops/binary_encoding", "polars-lazy?/binary_encoding", "polars-sql?/binary_encoding"] checked_arithmetic = ["polars-core/checked_arithmetic"] diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index 468a9ed949f39..3f0e84b3f9b52 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -148,8 +148,6 @@ //! //! Understanding polars expressions is most important when starting with the polars library. Read more //! about them in the [user guide](https://docs.pola.rs/user-guide/concepts/expressions). -//! Though the examples given there are in python. The expressions API is almost identical and the -//! the read should certainly be valuable to rust users as well. //! //! ### Eager //! Read more in the pages of the following data structures /traits. @@ -225,7 +223,7 @@ //! - `rows` - Create [`DataFrame`] from rows and extract rows from [`DataFrame`]s. //! And activates `pivot` and `transpose` operations //! - `asof_join` - Join ASOF, to join on nearest keys instead of exact equality match. -//! - `cross_join` - Create the cartesian product of two [`DataFrame`]s. +//! - `cross_join` - Create the Cartesian product of two [`DataFrame`]s. //! - `semi_anti_join` - SEMI and ANTI joins. //! - `group_by_list` - Allow group_by operation on keys of type List. //! - `row_hash` - Utility to hash [`DataFrame`] rows to [`UInt64Chunked`] diff --git a/crates/polars/tests/it/arrow/array/binary/mod.rs b/crates/polars/tests/it/arrow/array/binary/mod.rs new file mode 100644 index 0000000000000..3a44b67cbca99 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mod.rs @@ -0,0 +1,214 @@ +use arrow::array::{Array, BinaryArray}; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use polars_error::PolarsResult; + +mod mutable; +mod mutable_values; +mod to_mutable; + +#[test] +fn basics() { + let data = vec![Some(b"hello".to_vec()), None, Some(b"hello2".to_vec())]; + + let array: BinaryArray = data.into_iter().collect(); + + assert_eq!(array.value(0), b"hello"); + assert_eq!(array.value(1), b""); + assert_eq!(array.value(2), b"hello2"); + assert_eq!(unsafe { array.value_unchecked(2) }, b"hello2"); + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 5, 11]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = BinaryArray::::new( + ArrowDataType::Binary, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), b""); + assert_eq!(array.value(1), b"hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn empty() { + let array = BinaryArray::::new_empty(ArrowDataType::Binary); + assert_eq!(array.values().as_slice(), b""); + assert_eq!(array.offsets().as_slice(), &[0]); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let array = BinaryArray::::from([Some(b"hello".as_ref()), Some(b" ".as_ref()), None]); + + let a = array.validity().unwrap(); + assert_eq!(a, &Bitmap::from([true, true, false])); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(b"hello").take(2).map(Some); + let a = BinaryArray::::from_trusted_len_iter(iter); + assert_eq!(a.len(), 2); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat(b"hello".as_ref()) + .take(2) + .map(Some) + .map(PolarsResult::Ok); + let a = BinaryArray::::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(b"hello").take(2).map(Some); + let a: BinaryArray = iter.collect(); + assert_eq!(a.len(), 2); +} + +#[test] +fn with_validity() { + let array = BinaryArray::::from([Some(b"hello".as_ref()), Some(b" ".as_ref()), None]); + + let array = array.with_validity(None); + + let a = array.validity(); + assert_eq!(a, None); +} + +#[test] +#[should_panic] +fn wrong_offsets() { + let offsets = vec![0, 5, 4].try_into().unwrap(); // invalid offsets + let values = Buffer::from(b"abbbbb".to_vec()); + BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); +} + +#[test] +#[should_panic] +fn wrong_data_type() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + BinaryArray::::new(ArrowDataType::Int8, offsets, values, None); +} + +#[test] +#[should_panic] +fn value_with_wrong_offsets_panics() { + let offsets = vec![0, 10, 11, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + // the 10-11 is not checked + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + + // but access is still checked (and panics) + // without checks, this would result in reading beyond bounds + array.value(0); +} + +#[test] +#[should_panic] +fn index_out_of_bounds_panics() { + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + let array = BinaryArray::::new(ArrowDataType::Utf8, offsets, values, None); + + array.value(3); +} + +#[test] +#[should_panic] +fn value_unchecked_with_wrong_offsets_panics() { + let offsets = vec![0, 10, 11, 4].try_into().unwrap(); + let values = Buffer::from(b"abbb".to_vec()); + // the 10-11 is not checked + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + + // but access is still checked (and panics) + // without checks, this would result in reading beyond bounds, + // even if `0` is in bounds + unsafe { array.value_unchecked(0) }; +} + +#[test] +fn debug() { + let array = BinaryArray::::from([Some([1, 2].as_ref()), Some(&[]), None]); + + assert_eq!(format!("{array:?}"), "BinaryArray[[1, 2], [], None]"); +} + +#[test] +fn into_mut_1() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = offsets.clone(); // cloned offsets + assert_eq!(a, offsets); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let validity = Some([true].into()); + let a = validity.clone(); // cloned validity + assert_eq!(a, validity); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_4() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let validity = Some([true].into()); + let array = BinaryArray::::new(ArrowDataType::Binary, offsets, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn rev_iter() { + let array = BinaryArray::::from([Some("hello".as_bytes()), Some(" ".as_bytes()), None]); + + assert_eq!( + array.into_iter().rev().collect::>(), + vec![None, Some(" ".as_bytes()), Some("hello".as_bytes())] + ); +} + +#[test] +fn iter_nth() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + + assert_eq!(array.iter().nth(1), Some(Some(" ".as_bytes()))); + assert_eq!(array.iter().nth(10), None); +} diff --git a/crates/polars/tests/it/arrow/array/binary/mutable.rs b/crates/polars/tests/it/arrow/array/binary/mutable.rs new file mode 100644 index 0000000000000..d57deb22faa54 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mutable.rs @@ -0,0 +1,215 @@ +use std::ops::Deref; + +use arrow::array::{BinaryArray, MutableArray, MutableBinaryArray, TryExtendFromSelf}; +use arrow::bitmap::Bitmap; +use polars_error::PolarsError; + +#[test] +fn new() { + assert_eq!(MutableBinaryArray::::new().len(), 0); + + let a = MutableBinaryArray::::with_capacity(2); + assert_eq!(a.len(), 0); + assert!(a.offsets().capacity() >= 2); + assert_eq!(a.values().capacity(), 0); + + let a = MutableBinaryArray::::with_capacities(2, 60); + assert_eq!(a.len(), 0); + assert!(a.offsets().capacity() >= 2); + assert!(a.values().capacity() >= 60); +} + +#[test] +fn from_iter() { + let iter = (0..3u8).map(|x| Some(vec![x; x as usize])); + let a: MutableBinaryArray = iter.clone().collect(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = unsafe { MutableBinaryArray::::from_trusted_len_iter_unchecked(iter) }; + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); +} + +#[test] +fn from_trusted_len_iter() { + let data = [vec![0; 0], vec![1; 1], vec![2; 2]]; + let a: MutableBinaryArray = data.iter().cloned().map(Some).collect(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::from_trusted_len_iter(data.iter().cloned().map(Some)); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::try_from_trusted_len_iter::( + data.iter().cloned().map(Some).map(Ok), + ) + .unwrap(); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); + + let a = MutableBinaryArray::::from_trusted_len_values_iter(data.iter().cloned()); + assert_eq!(a.values().deref(), &[1u8, 2, 2]); + assert_eq!(a.offsets().as_slice(), &[0, 0, 1, 3]); + assert_eq!(a.validity(), None); +} + +#[test] +fn push_null() { + let mut array = MutableBinaryArray::::new(); + array.push::<&str>(None); + + let array: BinaryArray = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableBinaryArray::::new(); + a.push(Some(b"first")); + a.push(Some(b"second")); + a.push::>(None); + a.push_null(); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some(b"second".to_vec())); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(b"first".to_vec())); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutableBinaryArray::::new(); + a.push(Some(b"first")); + a.push(Some(b"second")); + a.push(Some(b"third")); + a.push(Some(b"fourth")); + + for _ in 0..4 { + a.push(Some(b"aaaa")); + } + + a.push(Some(b"bbbb")); + + assert_eq!(a.pop(), Some(b"bbbb".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.len(), 5); + assert_eq!(a.pop(), Some(b"aaaa".to_vec())); + assert_eq!(a.pop(), Some(b"fourth".to_vec())); + assert_eq!(a.pop(), Some(b"third".to_vec())); + assert_eq!(a.pop(), Some(b"second".to_vec())); + assert_eq!(a.pop(), Some(b"first".to_vec())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); +} + +#[test] +fn extend_trusted_len_values() { + let mut array = MutableBinaryArray::::new(); + + array.extend_trusted_len_values(vec![b"first".to_vec(), b"second".to_vec()].into_iter()); + array.extend_trusted_len_values(vec![b"third".to_vec()].into_iter()); + array.extend_trusted_len(vec![None, Some(b"fourth".to_vec())].into_iter()); + + let array: BinaryArray = array.into(); + + assert_eq!(array.values().as_slice(), b"firstsecondthirdfourth"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 11, 16, 16, 22]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00010111], 5)) + ); +} + +#[test] +fn extend_trusted_len() { + let mut array = MutableBinaryArray::::new(); + + array.extend_trusted_len(vec![Some(b"first".to_vec()), Some(b"second".to_vec())].into_iter()); + array.extend_trusted_len(vec![None, Some(b"third".to_vec())].into_iter()); + + let array: BinaryArray = array.into(); + + assert_eq!(array.values().as_slice(), b"firstsecondthird"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 11, 11, 16]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00001011], 4)) + ); +} + +#[test] +fn extend_from_self() { + let mut a = MutableBinaryArray::::from([Some(b"aa"), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableBinaryArray::::from([Some(b"aa"), None, Some(b"aa"), None]) + ); +} + +#[test] +fn test_set_validity() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + array.set_validity(Some([false, false, true].into())); + + assert!(!array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); +} + +#[test] +fn test_apply_validity() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} + +#[test] +fn test_apply_validity_with_no_validity_inited() { + let mut array = MutableBinaryArray::::new(); + array.push(Some(b"first")); + array.push(Some(b"second")); + array.push(Some(b"third")); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(array.is_valid(1)); + assert!(array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/binary/mutable_values.rs b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs new file mode 100644 index 0000000000000..c9e4f1da3bbec --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/mutable_values.rs @@ -0,0 +1,101 @@ +use arrow::array::{MutableArray, MutableBinaryValuesArray}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacity() { + let mut b = MutableBinaryValuesArray::::with_capacity(100); + + assert_eq!(b.values().capacity(), 0); + assert!(b.offsets().capacity() >= 100); + b.shrink_to_fit(); + assert!(b.offsets().capacity() < 100); +} + +#[test] +fn offsets_must_be_in_bounds() { + let offsets = vec![0, 10].try_into().unwrap(); + let values = b"abbbbb".to_vec(); + assert!( + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).is_err() + ); +} + +#[test] +fn data_type_must_be_consistent() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec(); + assert!( + MutableBinaryValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err() + ); +} + +#[test] +fn as_box() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + let _ = b.as_box(); +} + +#[test] +fn as_arc() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + let _ = b.as_arc(); +} + +#[test] +fn extend_trusted_len() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 2, 3, 4].try_into().unwrap(); + let values = b"abab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn from_trusted_len() { + let mut b = MutableBinaryValuesArray::::from_trusted_len_iter(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn extend_from_iter() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let a = b.clone(); + b.extend_trusted_len(a.iter()); + + let offsets = vec![0, 2, 3, 4, 6, 7, 8].try_into().unwrap(); + let values = b"abababab".to_vec(); + assert_eq!( + b.as_box(), + MutableBinaryValuesArray::::try_new(ArrowDataType::Binary, offsets, values) + .unwrap() + .as_box() + ) +} diff --git a/crates/polars/tests/it/arrow/array/binary/to_mutable.rs b/crates/polars/tests/it/arrow/array/binary/to_mutable.rs new file mode 100644 index 0000000000000..8f07d3a166b33 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/binary/to_mutable.rs @@ -0,0 +1,70 @@ +use arrow::array::BinaryArray; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +#[test] +fn not_shared() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + assert!(array.into_mut().is_right()); +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_validity() { + let validity = Bitmap::from([true]); + let array = BinaryArray::::new( + ArrowDataType::Binary, + vec![0, 1].try_into().unwrap(), + b"a".to_vec().into(), + Some(validity.clone()), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_values() { + let values: Buffer = b"a".to_vec().into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + vec![0, 1].try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets_values() { + let offsets: Buffer = vec![0, 1].into(); + let values: Buffer = b"a".to_vec().into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + offsets.clone().try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets() { + let offsets: Buffer = vec![0, 1].into(); + let array = BinaryArray::::new( + ArrowDataType::Binary, + offsets.clone().try_into().unwrap(), + b"a".to_vec().into(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_all() { + let array = BinaryArray::::from([Some("hello"), Some(" "), None]); + assert!(array.clone().into_mut().is_left()) +} diff --git a/crates/polars/tests/it/arrow/array/boolean/mod.rs b/crates/polars/tests/it/arrow/array/boolean/mod.rs new file mode 100644 index 0000000000000..8b3a4e1e1b709 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/boolean/mod.rs @@ -0,0 +1,146 @@ +use arrow::array::{Array, BooleanArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +mod mutable; + +#[test] +fn basics() { + let data = vec![Some(true), None, Some(false)]; + + let array: BooleanArray = data.into_iter().collect(); + + assert_eq!(array.data_type(), &ArrowDataType::Boolean); + + assert!(array.value(0)); + assert!(!array.value(1)); + assert!(!array.value(2)); + assert!(!unsafe { array.value_unchecked(2) }); + assert_eq!(array.values(), &Bitmap::from_u8_slice([0b00000001], 3)); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = BooleanArray::new( + ArrowDataType::Boolean, + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert!(!array.value(0)); + assert!(!array.value(1)); +} + +#[test] +fn try_new_invalid() { + assert!(BooleanArray::try_new(ArrowDataType::Int32, [true].into(), None).is_err()); + assert!(BooleanArray::try_new( + ArrowDataType::Boolean, + [true].into(), + Some([false, true].into()) + ) + .is_err()); +} + +#[test] +fn with_validity() { + let bitmap = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let a = a.with_validity(Some(Bitmap::from([true, false, true]))); + assert!(a.validity().is_some()); +} + +#[test] +fn debug() { + let array = BooleanArray::from([Some(true), None, Some(false)]); + assert_eq!(format!("{array:?}"), "BooleanArray[true, None, false]"); +} + +#[test] +fn into_mut_valid() { + let bitmap = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let _ = a.into_mut().right().unwrap(); + + let bitmap = Bitmap::from([true, false, true]); + let validity = Bitmap::from([true, false, true]); + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, Some(validity)); + let _ = a.into_mut().right().unwrap(); +} + +#[test] +fn into_mut_invalid() { + let bitmap = Bitmap::from([true, false, true]); + let _other = bitmap.clone(); // values is shared + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, None); + let _ = a.into_mut().left().unwrap(); + + let bitmap = Bitmap::from([true, false, true]); + let validity = Bitmap::from([true, false, true]); + let _other = validity.clone(); // validity is shared + let a = BooleanArray::new(ArrowDataType::Boolean, bitmap, Some(validity)); + let _ = a.into_mut().left().unwrap(); +} + +#[test] +fn empty() { + let array = BooleanArray::new_empty(ArrowDataType::Boolean); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a = BooleanArray::from_trusted_len_iter(iter.clone()); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::from_trusted_len_iter_unchecked(iter) }; + assert_eq!(a.len(), 2); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat(true) + .take(2) + .map(Some) + .map(PolarsResult::Ok); + let a = BooleanArray::try_from_trusted_len_iter(iter.clone()).unwrap(); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::try_from_trusted_len_iter_unchecked(iter).unwrap() }; + assert_eq!(a.len(), 2); +} + +#[test] +fn from_trusted_len_values_iter() { + let iter = std::iter::repeat(true).take(2); + let a = BooleanArray::from_trusted_len_values_iter(iter.clone()); + assert_eq!(a.len(), 2); + let a = unsafe { BooleanArray::from_trusted_len_values_iter_unchecked(iter) }; + assert_eq!(a.len(), 2); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a: BooleanArray = iter.collect(); + assert_eq!(a.len(), 2); +} + +#[test] +fn into_iter() { + let data = vec![Some(true), None, Some(false)]; + let rev = data.clone().into_iter().rev(); + + let array: BooleanArray = data.clone().into_iter().collect(); + + assert_eq!(array.clone().into_iter().collect::>(), data); + + assert!(array.into_iter().rev().eq(rev)) +} diff --git a/crates/polars/tests/it/arrow/array/boolean/mutable.rs b/crates/polars/tests/it/arrow/array/boolean/mutable.rs new file mode 100644 index 0000000000000..1071a1ed8c37d --- /dev/null +++ b/crates/polars/tests/it/arrow/array/boolean/mutable.rs @@ -0,0 +1,177 @@ +use arrow::array::{MutableArray, MutableBooleanArray, TryExtendFromSelf}; +use arrow::bitmap::MutableBitmap; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +#[test] +fn set() { + let mut a = MutableBooleanArray::from(&[Some(false), Some(true), Some(false)]); + + a.set(1, None); + a.set(0, Some(true)); + assert_eq!( + a, + MutableBooleanArray::from([Some(true), None, Some(false)]) + ); + assert_eq!(a.values(), &MutableBitmap::from([true, false, false])); +} + +#[test] +fn push() { + let mut a = MutableBooleanArray::new(); + a.push(Some(true)); + a.push(Some(false)); + a.push(None); + a.push_null(); + assert_eq!( + a, + MutableBooleanArray::from([Some(true), Some(false), None, None]) + ); +} + +#[test] +fn pop() { + let mut a = MutableBooleanArray::new(); + a.push(Some(true)); + a.push(Some(false)); + a.push(None); + a.push_null(); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(true)); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutableBooleanArray::new(); + for _ in 0..4 { + a.push(Some(true)); + } + + for _ in 0..4 { + a.push(Some(false)); + } + + a.push(Some(true)); + + assert_eq!(a.pop(), Some(true)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.pop(), Some(false)); + assert_eq!(a.len(), 5); + + assert_eq!( + a, + MutableBooleanArray::from([Some(true), Some(true), Some(true), Some(true), Some(false)]) + ); +} + +#[test] +fn from_trusted_len_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a = MutableBooleanArray::from_trusted_len_iter(iter); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true)])); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(true).take(2).map(Some); + let a: MutableBooleanArray = iter.collect(); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true)])); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = vec![Some(true), Some(true), None] + .into_iter() + .map(PolarsResult::Ok); + let a = MutableBooleanArray::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a, MutableBooleanArray::from([Some(true), Some(true), None])); +} + +#[test] +fn reserve() { + let mut a = MutableBooleanArray::try_new( + ArrowDataType::Boolean, + MutableBitmap::new(), + Some(MutableBitmap::new()), + ) + .unwrap(); + + a.reserve(10); + assert!(a.validity().unwrap().capacity() > 0); + assert!(a.values().capacity() > 0) +} + +#[test] +fn extend_trusted_len() { + let mut a = MutableBooleanArray::new(); + + a.extend_trusted_len(vec![Some(true), Some(false)].into_iter()); + assert_eq!(a.validity(), None); + + a.extend_trusted_len(vec![None, Some(true)].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, true, false, true])) + ); + assert_eq!(a.values(), &MutableBitmap::from([true, false, false, true])); +} + +#[test] +fn extend_trusted_len_values() { + let mut a = MutableBooleanArray::new(); + + a.extend_trusted_len_values(vec![true, true, false].into_iter()); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &MutableBitmap::from([true, true, false])); + + let mut a = MutableBooleanArray::new(); + a.push(None); + a.extend_trusted_len_values(vec![true, false].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); + assert_eq!(a.values(), &MutableBitmap::from([false, true, false])); +} + +#[test] +fn into_iter() { + let ve = MutableBitmap::from([true, false]) + .into_iter() + .collect::>(); + assert_eq!(ve, vec![true, false]); + let ve = MutableBitmap::from([true, false]) + .iter() + .collect::>(); + assert_eq!(ve, vec![true, false]); +} + +#[test] +fn shrink_to_fit() { + let mut a = MutableBitmap::with_capacity(100); + a.push(true); + a.shrink_to_fit(); + assert_eq!(a.capacity(), 8); +} + +#[test] +fn extend_from_self() { + let mut a = MutableBooleanArray::from([Some(true), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableBooleanArray::from([Some(true), None, Some(true), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/dictionary/mod.rs b/crates/polars/tests/it/arrow/array/dictionary/mod.rs new file mode 100644 index 0000000000000..e14b065e7536d --- /dev/null +++ b/crates/polars/tests/it/arrow/array/dictionary/mod.rs @@ -0,0 +1,214 @@ +mod mutable; + +use arrow::array::*; +use arrow::datatypes::ArrowDataType; + +#[test] +fn try_new_ok() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = + ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(values.data_type().clone()), false); + let array = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .unwrap(); + + assert_eq!(array.keys(), &PrimitiveArray::from_vec(vec![1i32, 0])); + assert_eq!( + &Utf8Array::::from_slice(["a", "aa"]) as &dyn Array, + array.values().as_ref(), + ); + assert!(!array.is_ordered()); + + assert_eq!(format!("{array:?}"), "DictionaryArray[aa, a]"); +} + +#[test] +fn try_new_incorrect_key() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = + ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(values.data_type().clone()), false); + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_nulls() { + let key: Option = None; + let keys = PrimitiveArray::from_iter([key]); + let value: &[&str] = &[]; + let values = Utf8Array::::from_slice(value); + + let data_type = + ArrowDataType::Dictionary(u32::KEY_TYPE, Box::new(values.data_type().clone()), false); + let r = DictionaryArray::try_new(data_type, keys, values.boxed()).is_ok(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_dt() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = ArrowDataType::Int32; + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_values_dt() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let data_type = + ArrowDataType::Dictionary(i32::KEY_TYPE, Box::new(ArrowDataType::LargeUtf8), false); + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds() { + let values = Utf8Array::::from_slice(["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![2, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds_neg() { + let values = Utf8Array::::from_slice(["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![-1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn new_null() { + let dt = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(ArrowDataType::Int32), false); + let array = DictionaryArray::::new_null(dt, 2); + + assert_eq!(format!("{array:?}"), "DictionaryArray[None, None]"); +} + +#[test] +fn new_empty() { + let dt = ArrowDataType::Dictionary(i16::KEY_TYPE, Box::new(ArrowDataType::Int32), false); + let array = DictionaryArray::::new_empty(dt); + + assert_eq!(format!("{array:?}"), "DictionaryArray[]"); +} + +#[test] +fn with_validity() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let array = array.with_validity(Some([true, false].into())); + + assert_eq!(format!("{array:?}"), "DictionaryArray[aa, None]"); +} + +#[test] +fn rev_iter() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.into_iter(); + assert_eq!(iter.by_ref().rev().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn iter_values() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.values_iter(); + assert_eq!(iter.by_ref().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn keys_values_iter() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + assert_eq!(array.keys_values_iter().collect::>(), vec![1, 0]); +} + +#[test] +fn iter_values_typed() { + let values = Utf8Array::::from_slice(["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + let iter = array.values_iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!(iter.collect::>(), vec!["aa", "a", "a"]); + + let iter = array.iter_typed::>().unwrap(); + assert_eq!(iter.size_hint(), (3, Some(3))); + assert_eq!( + iter.collect::>(), + vec![Some("aa"), Some("a"), Some("a")] + ); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let iter = array.values_iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} + +#[test] +#[should_panic] +fn iter_values_typed_panic_2() { + let values = Utf8Array::::from_iter([Some("a"), Some("aa"), None]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed()) + .unwrap(); + + // should not be iterating values + let iter = array.iter_typed::>().unwrap(); + let _ = iter.collect::>(); +} diff --git a/crates/polars/tests/it/arrow/array/dictionary/mutable.rs b/crates/polars/tests/it/arrow/array/dictionary/mutable.rs new file mode 100644 index 0000000000000..cc8b0774533ad --- /dev/null +++ b/crates/polars/tests/it/arrow/array/dictionary/mutable.rs @@ -0,0 +1,169 @@ +use std::borrow::Borrow; +use std::fmt::Debug; +use std::hash::Hash; + +use arrow::array::indexable::{AsIndexed, Indexable}; +use arrow::array::*; +use polars_error::PolarsResult; +use polars_utils::aliases::{InitHashMaps, PlHashSet}; + +#[test] +fn primitive() -> PolarsResult<()> { + let data = vec![Some(1), Some(2), Some(1)]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn utf8_natural() -> PolarsResult<()> { + let data = vec![Some("a"), Some("b"), Some("a")]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn binary_natural() -> PolarsResult<()> { + let data = vec![ + Some("a".as_bytes()), + Some("b".as_bytes()), + Some("a".as_bytes()), + ]; + + let mut a = MutableDictionaryArray::>::new(); + a.try_extend(data)?; + assert_eq!(a.len(), 3); + assert_eq!(a.values().len(), 2); + Ok(()) +} + +#[test] +fn push_utf8() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + + assert_eq!( + new.values().values(), + MutableUtf8Array::::from_iter_values(["A", "B", "C"].into_iter()).values() + ); + + let mut expected_keys = MutablePrimitiveArray::::from_slice([0, 1]); + expected_keys.push(None); + expected_keys.push(Some(2)); + expected_keys.push(Some(0)); + expected_keys.push(Some(1)); + assert_eq!(*new.keys(), expected_keys); +} + +#[test] +fn into_empty() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + let values = new.values().clone(); + let empty = new.into_empty(); + assert_eq!(empty.values(), &values); + assert!(empty.is_empty()); +} + +#[test] +fn from_values() { + let mut new: MutableDictionaryArray> = MutableDictionaryArray::new(); + for value in [Some("A"), Some("B"), None, Some("C"), Some("A"), Some("B")] { + new.try_push(value).unwrap(); + } + let mut values = new.values().clone(); + let empty = MutableDictionaryArray::::from_values(values.clone()).unwrap(); + assert_eq!(empty.values(), &values); + assert!(empty.is_empty()); + values.push(Some("A")); + assert!(MutableDictionaryArray::::from_values(values).is_err()); +} + +#[test] +fn try_empty() { + let mut values = MutableUtf8Array::::new(); + MutableDictionaryArray::::try_empty(values.clone()).unwrap(); + values.push(Some("A")); + assert!(MutableDictionaryArray::::try_empty(values.clone()).is_err()); +} + +fn test_push_ex(values: Vec, gen: impl Fn(usize) -> T) +where + M: MutableArray + Indexable + TryPush> + TryExtend> + Default + 'static, + M::Type: Eq + Hash + Debug, + T: AsIndexed + Default + Clone + Eq + Hash, +{ + for is_extend in [false, true] { + let mut set = PlHashSet::new(); + let mut arr = MutableDictionaryArray::::new(); + macro_rules! push { + ($v:expr) => { + if is_extend { + arr.try_extend(std::iter::once($v)) + } else { + arr.try_push($v) + } + }; + } + arr.push_null(); + push!(None).unwrap(); + assert_eq!(arr.len(), 2); + assert_eq!(arr.values().len(), 0); + for (i, v) in values.iter().cloned().enumerate() { + push!(Some(v.clone())).unwrap(); + let is_dup = !set.insert(v.clone()); + if !is_dup { + assert_eq!(arr.values().value_at(i).borrow(), v.as_indexed()); + assert_eq!(arr.keys().value_at(arr.keys().len() - 1), i as u8); + } + assert_eq!(arr.values().len(), set.len()); + assert_eq!(arr.len(), 3 + i); + } + for i in 0..256 - set.len() { + push!(Some(gen(i))).unwrap(); + } + assert!(push!(Some(gen(256))).is_err()); + } +} + +#[test] +fn test_push_utf8_ex() { + test_push_ex::, _>(vec!["a".into(), "b".into(), "a".into()], |i| { + i.to_string() + }) +} + +#[test] +fn test_push_i64_ex() { + test_push_ex::, _>(vec![10, 20, 30, 20], |i| 1000 + i as i64); +} + +#[test] +fn test_big_dict() { + let n = 10; + let strings = (0..10).map(|i| i.to_string()).collect::>(); + let mut arr = MutableDictionaryArray::>::new(); + for s in &strings { + arr.try_push(Some(s)).unwrap(); + } + assert_eq!(arr.values().len(), n); + for _ in 0..10_000 { + for s in &strings { + arr.try_push(Some(s)).unwrap(); + } + } + assert_eq!(arr.values().len(), n); +} diff --git a/crates/polars/tests/it/arrow/array/equal/boolean.rs b/crates/polars/tests/it/arrow/array/equal/boolean.rs new file mode 100644 index 0000000000000..e20be510879f2 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/boolean.rs @@ -0,0 +1,53 @@ +use arrow::array::*; + +use super::test_equal; + +#[test] +fn test_boolean_equal() { + let a = BooleanArray::from_slice([false, false, true]); + let b = BooleanArray::from_slice([false, false, true]); + test_equal(&a, &b, true); + + let b = BooleanArray::from_slice([false, false, false]); + test_equal(&a, &b, false); +} + +#[test] +fn test_boolean_equal_null() { + let a = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + let b = BooleanArray::from(vec![Some(false), None, None, Some(true)]); + test_equal(&a, &b, true); + + let b = BooleanArray::from(vec![None, None, None, Some(true)]); + test_equal(&a, &b, false); + + let b = BooleanArray::from(vec![Some(true), None, None, Some(true)]); + test_equal(&a, &b, false); +} + +#[test] +fn test_boolean_equal_offset() { + let a = BooleanArray::from_slice(vec![false, true, false, true, false, false, true]); + let b = BooleanArray::from_slice(vec![true, false, false, false, true, false, true, true]); + test_equal(&a, &b, false); + + let a_slice = a.sliced(2, 3); + let b_slice = b.sliced(3, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.sliced(3, 4); + let b_slice = b.sliced(4, 4); + test_equal(&a_slice, &b_slice, false); + + // Elements fill in `u8`'s exactly. + let mut vector = vec![false, false, true, true, true, true, true, true]; + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector.clone()); + test_equal(&a, &b, true); + + // Elements fill in `u8`s + suffix bits. + vector.push(true); + let a = BooleanArray::from_slice(vector.clone()); + let b = BooleanArray::from_slice(vector); + test_equal(&a, &b, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/dictionary.rs b/crates/polars/tests/it/arrow/array/equal/dictionary.rs new file mode 100644 index 0000000000000..b429c71b4e690 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/dictionary.rs @@ -0,0 +1,97 @@ +use arrow::array::*; + +use super::test_equal; + +fn create_dictionary_array(values: &[Option<&str>], keys: &[Option]) -> DictionaryArray { + let keys = Int16Array::from(keys); + let values = Utf8Array::::from(values); + + DictionaryArray::try_from_keys(keys, values.boxed()).unwrap() +} + +#[test] +fn dictionary_equal() { + // (a, b, c), (0, 1, 0, 2) => (a, b, a, c) + let a = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), Some(1), Some(0), Some(2)], + ); + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), Some(1)], + ); + test_equal(&a, &b, true); + + // different len + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(1)], + ); + test_equal(&a, &b, false); + + // different key + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), Some(0)], + ); + test_equal(&a, &b, false); + + // different values, same keys + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("d")], + &[Some(0), Some(1), Some(0), Some(2)], + ); + test_equal(&a, &b, false); +} + +#[test] +fn dictionary_equal_null() { + // (a, b, c), (1, 2, 1, 3) => (a, b, a, c) + let a = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), None, Some(0), Some(2)], + ); + + // equal to self + test_equal(&a, &a, true); + + // different representation (values and keys are swapped), same result + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), None, Some(0), Some(1)], + ); + test_equal(&a, &b, true); + + // different null position + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), Some(2), Some(0), None], + ); + test_equal(&a, &b, false); + + // different key + let b = create_dictionary_array( + &[Some("a"), Some("c"), Some("b")], + &[Some(0), None, Some(0), Some(0)], + ); + test_equal(&a, &b, false); + + // different values, same keys + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("d")], + &[Some(0), None, Some(0), Some(2)], + ); + test_equal(&a, &b, false); + + // different nulls in keys and values + let a = create_dictionary_array( + &[Some("a"), Some("b"), None], + &[Some(0), None, Some(0), Some(2)], + ); + let b = create_dictionary_array( + &[Some("a"), Some("b"), Some("c")], + &[Some(0), None, Some(0), None], + ); + test_equal(&a, &b, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs b/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs new file mode 100644 index 0000000000000..04238ab7362fa --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/fixed_size_list.rs @@ -0,0 +1,84 @@ +use arrow::array::{ + FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryExtend, +}; + +use super::test_equal; + +/// Create a fixed size list of 2 value lengths +fn create_fixed_size_list_array, T: AsRef<[Option]>>( + data: T, +) -> FixedSizeListArray { + let data = data.as_ref().iter().map(|x| { + Some(match x { + Some(x) => x.as_ref().iter().map(|x| Some(*x)).collect::>(), + None => std::iter::repeat(None).take(3).collect::>(), + }) + }); + + let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + list.try_extend(data).unwrap(); + list.into() +} + +#[test] +fn test_fixed_size_list_equal() { + let a = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_fixed_size_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); +} + +// Test the case where null_count > 0 +#[test] +fn test_fixed_list_null() { + let a = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None]); + /* + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + None, + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(&a, &b, true); + + let b = create_fixed_size_list_array(&[ + Some(&[1, 2, 3]), + None, + Some(&[7, 8, 9]), + Some(&[4, 5, 6]), + None, + None, + ]); + test_equal(&a, &b, false); + */ + + let b = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None]); + test_equal(&a, &b, false); +} + +#[test] +fn test_fixed_list_offsets() { + // Test the case where offset != 0 + let a = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[4, 5, 6]), None, None]); + let b = + create_fixed_size_list_array([Some(&[1, 2, 3]), None, None, Some(&[3, 6, 9]), None, None]); + + let a_slice = a.clone().sliced(0, 3); + let b_slice = b.clone().sliced(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.clone().sliced(0, 5); + let b_slice = b.clone().sliced(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.sliced(4, 1); + let b_slice = b.sliced(4, 1); + test_equal(&a_slice, &b_slice, true); +} diff --git a/crates/polars/tests/it/arrow/array/equal/list.rs b/crates/polars/tests/it/arrow/array/equal/list.rs new file mode 100644 index 0000000000000..34370ad5459e9 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/list.rs @@ -0,0 +1,90 @@ +use arrow::array::{Int32Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +use super::test_equal; + +fn create_list_array, T: AsRef<[Option]>>(data: T) -> ListArray { + let iter = data.as_ref().iter().map(|x| { + x.as_ref() + .map(|x| x.as_ref().iter().map(|x| Some(*x)).collect::>()) + }); + let mut array = MutableListArray::>::new(); + array.try_extend(iter).unwrap(); + array.into() +} + +#[test] +fn test_list_equal() { + let a = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + let b = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 6])]); + test_equal(&a, &b, true); + + let b = create_list_array([Some(&[1, 2, 3]), Some(&[4, 5, 7])]); + test_equal(&a, &b, false); +} + +// Test the case where null_count > 0 +#[test] +fn test_list_null() { + let a = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + test_equal(&a, &b, true); + + let b = create_list_array([ + Some(&[1, 2]), + None, + Some(&[5, 6]), + Some(&[3, 4]), + None, + None, + ]); + test_equal(&a, &b, false); + + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + test_equal(&a, &b, false); +} + +// Test the case where offset != 0 +#[test] +fn test_list_offsets() { + let a = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 4]), None, None]); + let b = create_list_array([Some(&[1, 2]), None, None, Some(&[3, 5]), None, None]); + + let a_slice = a.clone().sliced(0, 3); + let b_slice = b.clone().sliced(0, 3); + test_equal(&a_slice, &b_slice, true); + + let a_slice = a.clone().sliced(0, 5); + let b_slice = b.clone().sliced(0, 5); + test_equal(&a_slice, &b_slice, false); + + let a_slice = a.sliced(4, 1); + let b_slice = b.sliced(4, 1); + test_equal(&a_slice, &b_slice, true); +} + +#[test] +fn test_bla() { + let offsets = vec![0, 3, 3, 6].try_into().unwrap(); + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let values = Box::new(Int32Array::from([ + Some(1), + Some(2), + Some(3), + Some(4), + None, + Some(6), + ])); + let validity = Bitmap::from([true, false, true]); + let lhs = ListArray::::new(data_type, offsets, values, Some(validity)); + let lhs = lhs.sliced(1, 2); + + let offsets = vec![0, 0, 3].try_into().unwrap(); + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let values = Box::new(Int32Array::from([Some(4), None, Some(6)])); + let validity = Bitmap::from([false, true]); + let rhs = ListArray::::new(data_type, offsets, values, Some(validity)); + + assert_eq!(lhs, rhs); +} diff --git a/crates/polars/tests/it/arrow/array/equal/mod.rs b/crates/polars/tests/it/arrow/array/equal/mod.rs new file mode 100644 index 0000000000000..87f7ffeff251c --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/mod.rs @@ -0,0 +1,50 @@ +use arrow::array::*; + +mod dictionary; +mod fixed_size_list; +mod list; +mod primitive; +mod utf8; + +pub fn test_equal(lhs: &dyn Array, rhs: &dyn Array, expected: bool) { + // equality is symmetric + assert!(equal(lhs, lhs), "\n{lhs:?}\n{lhs:?}"); + assert!(equal(rhs, rhs), "\n{rhs:?}\n{rhs:?}"); + + assert_eq!(equal(lhs, rhs), expected, "\n{lhs:?}\n{rhs:?}"); + assert_eq!(equal(rhs, lhs), expected, "\n{rhs:?}\n{lhs:?}"); +} + +#[allow(clippy::type_complexity)] +fn binary_cases() -> Vec<(Vec>, Vec>, bool)> { + let base = vec![ + Some("hello".to_owned()), + None, + None, + Some("world".to_owned()), + None, + None, + ]; + let not_base = vec![ + Some("hello".to_owned()), + Some("foo".to_owned()), + None, + Some("world".to_owned()), + None, + None, + ]; + vec![ + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("world".to_owned())], + true, + ), + ( + vec![Some("hello".to_owned()), Some("world".to_owned())], + vec![Some("hello".to_owned()), Some("arrow".to_owned())], + false, + ), + (base.clone(), base.clone(), true), + (base, not_base, false), + ] +} diff --git a/crates/polars/tests/it/arrow/array/equal/primitive.rs b/crates/polars/tests/it/arrow/array/equal/primitive.rs new file mode 100644 index 0000000000000..e50711eb97284 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/primitive.rs @@ -0,0 +1,90 @@ +use arrow::array::*; + +use super::test_equal; + +#[test] +fn test_primitive() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(3)], + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + vec![Some(1), Some(2), Some(4)], + false, + ), + ( + vec![Some(1), Some(2), None], + vec![Some(1), Some(2), None], + true, + ), + ( + vec![Some(1), None, Some(3)], + vec![Some(1), Some(2), None], + false, + ), + ( + vec![Some(1), None, None], + vec![Some(1), Some(2), None], + false, + ), + ]; + + for (lhs, rhs, expected) in cases { + let lhs = Int32Array::from(&lhs); + let rhs = Int32Array::from(&rhs); + test_equal(&lhs, &rhs, expected); + } +} + +#[test] +fn test_primitive_slice() { + let cases = vec![ + ( + vec![Some(1), Some(2), Some(3)], + (0, 1), + vec![Some(1), Some(2), Some(3)], + (0, 1), + true, + ), + ( + vec![Some(1), Some(2), Some(3)], + (1, 1), + vec![Some(1), Some(2), Some(3)], + (2, 1), + false, + ), + ( + vec![Some(1), Some(2), None], + (1, 1), + vec![Some(1), None, Some(2)], + (2, 1), + true, + ), + ( + vec![None, Some(2), None], + (1, 1), + vec![None, None, Some(2)], + (2, 1), + true, + ), + ( + vec![Some(1), None, Some(2), None, Some(3)], + (2, 2), + vec![None, Some(2), None, Some(3)], + (1, 2), + true, + ), + ]; + + for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { + let lhs = Int32Array::from(&lhs); + let lhs = lhs.sliced(slice_lhs.0, slice_lhs.1); + let rhs = Int32Array::from(&rhs); + let rhs = rhs.sliced(slice_rhs.0, slice_rhs.1); + + test_equal(&lhs, &rhs, expected); + } +} diff --git a/crates/polars/tests/it/arrow/array/equal/utf8.rs b/crates/polars/tests/it/arrow/array/equal/utf8.rs new file mode 100644 index 0000000000000..a9f9e6cff0694 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/equal/utf8.rs @@ -0,0 +1,26 @@ +use arrow::array::*; +use arrow::offset::Offset; + +use super::{binary_cases, test_equal}; + +fn test_generic_string_equal() { + let cases = binary_cases(); + + for (lhs, rhs, expected) in cases { + let lhs = lhs.iter().map(|x| x.as_deref()); + let rhs = rhs.iter().map(|x| x.as_deref()); + let lhs = Utf8Array::::from_trusted_len_iter(lhs); + let rhs = Utf8Array::::from_trusted_len_iter(rhs); + test_equal(&lhs, &rhs, expected); + } +} + +#[test] +fn utf8_equal() { + test_generic_string_equal::() +} + +#[test] +fn large_utf8_equal() { + test_generic_string_equal::() +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs new file mode 100644 index 0000000000000..12019be642059 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mod.rs @@ -0,0 +1,103 @@ +use arrow::array::FixedSizeBinaryArray; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +mod mutable; + +#[test] +fn basics() { + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + Buffer::from(vec![1, 2, 3, 4, 5, 6]), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(array.size(), 2); + assert_eq!(array.len(), 3); + assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); + + assert_eq!(array.value(0), [1, 2]); + assert_eq!(array.value(2), [5, 6]); + + let array = array.sliced(1, 2); + + assert_eq!(array.value(1), [5, 6]); +} + +#[test] +fn with_validity() { + let a = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + vec![1, 2, 3, 4, 5, 6].into(), + None, + ); + let a = a.with_validity(Some(Bitmap::from([true, false, true]))); + assert!(a.validity().is_some()); +} + +#[test] +fn debug() { + let a = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + vec![1, 2, 3, 4, 5, 6].into(), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(format!("{a:?}"), "FixedSizeBinary(2)[[1, 2], None, [5, 6]]"); +} + +#[test] +fn empty() { + let array = FixedSizeBinaryArray::new_empty(ArrowDataType::FixedSizeBinary(2)); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn null() { + let array = FixedSizeBinaryArray::new_null(ArrowDataType::FixedSizeBinary(2), 2); + assert_eq!(array.values().len(), 4); + assert_eq!(array.validity().cloned(), Some([false, false].into())); +} + +#[test] +fn from_iter() { + let iter = std::iter::repeat(vec![1u8, 2]).take(2).map(Some); + let a = FixedSizeBinaryArray::from_iter(iter, 2); + assert_eq!(a.len(), 2); +} + +#[test] +fn wrong_size() { + let values = Buffer::from(b"abb".to_vec()); + assert!( + FixedSizeBinaryArray::try_new(ArrowDataType::FixedSizeBinary(2), values, None).is_err() + ); +} + +#[test] +fn wrong_len() { + let values = Buffer::from(b"abba".to_vec()); + let validity = Some([true, false, false].into()); // it should be 2 + assert!( + FixedSizeBinaryArray::try_new(ArrowDataType::FixedSizeBinary(2), values, validity).is_err() + ); +} + +#[test] +fn wrong_data_type() { + let values = Buffer::from(b"abba".to_vec()); + assert!(FixedSizeBinaryArray::try_new(ArrowDataType::Binary, values, None).is_err()); +} + +#[test] +fn to() { + let values = Buffer::from(b"abba".to_vec()); + let a = FixedSizeBinaryArray::new(ArrowDataType::FixedSizeBinary(2), values, None); + + let extension = ArrowDataType::Extension( + "a".to_string(), + Box::new(ArrowDataType::FixedSizeBinary(2)), + None, + ); + let _ = a.to(extension); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs new file mode 100644 index 0000000000000..316157087fbbc --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_binary/mutable.rs @@ -0,0 +1,173 @@ +use arrow::array::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn basic() { + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + None, + ) + .unwrap(); + assert_eq!(a.len(), 2); + assert_eq!(a.data_type(), &ArrowDataType::FixedSizeBinary(2)); + assert_eq!(a.values(), &Vec::from([1, 2, 3, 4])); + assert_eq!(a.validity(), None); + assert_eq!(a.value(1), &[3, 4]); + assert_eq!(unsafe { a.value_unchecked(1) }, &[3, 4]); +} + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + None, + ) + .unwrap(); + assert_eq!(a, a); + let b = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2]), + None, + ) + .unwrap(); + assert_eq!(b, b); + assert!(a != b); + let a = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let b = MutableFixedSizeBinaryArray::try_new( + ArrowDataType::FixedSizeBinary(2), + Vec::from([1, 2, 3, 4]), + Some(MutableBitmap::from([false, true])), + ) + .unwrap(); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); +} + +#[test] +fn try_from_iter() { + let array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + assert_eq!(array.len(), 4); +} + +#[test] +fn push_null() { + let mut array = MutableFixedSizeBinaryArray::new(2); + array.push::<&[u8]>(None); + + let array: FixedSizeBinaryArray = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableFixedSizeBinaryArray::new(2); + a.push(Some(b"aa")); + a.push::<&[u8]>(None); + a.push(Some(b"bb")); + a.push::<&[u8]>(None); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), Some(b"bb".to_vec())); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some(b"aa".to_vec())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); + assert!(a.is_empty()); +} + +#[test] +fn pop_all_some() { + let mut a = MutableFixedSizeBinaryArray::new(2); + a.push(Some(b"aa")); + a.push(Some(b"bb")); + a.push(Some(b"cc")); + a.push(Some(b"dd")); + + for _ in 0..4 { + a.push(Some(b"11")); + } + + a.push(Some(b"22")); + + assert_eq!(a.pop(), Some(b"22".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.pop(), Some(b"11".to_vec())); + assert_eq!(a.len(), 5); + + assert_eq!( + a, + MutableFixedSizeBinaryArray::try_from_iter( + vec![ + Some(b"aa"), + Some(b"bb"), + Some(b"cc"), + Some(b"dd"), + Some(b"11"), + ], + 2, + ) + .unwrap() + ); +} + +#[test] +fn as_arc() { + let mut array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + + let array = array.as_arc(); + assert_eq!(array.len(), 4); +} + +#[test] +fn as_box() { + let mut array = MutableFixedSizeBinaryArray::try_from_iter( + vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], + 2, + ) + .unwrap(); + + let array = array.as_box(); + assert_eq!(array.len(), 4); +} + +#[test] +fn shrink_to_fit_and_capacity() { + let mut array = MutableFixedSizeBinaryArray::with_capacity(2, 100); + array.push(Some([1, 2])); + array.shrink_to_fit(); + assert_eq!(array.capacity(), 1); +} + +#[test] +fn extend_from_self() { + let mut a = MutableFixedSizeBinaryArray::from([Some([1u8, 2u8]), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableFixedSizeBinaryArray::from([Some([1u8, 2u8]), None, Some([1u8, 2u8]), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs new file mode 100644 index 0000000000000..d178b27e190b8 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mod.rs @@ -0,0 +1,102 @@ +mod mutable; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field}; + +fn data() -> FixedSizeListArray { + let values = Int32Array::from_slice([10, 20, 0, 0]); + + FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList( + Box::new(Field::new("a", values.data_type().clone(), true)), + 2, + ), + values.boxed(), + Some([true, false].into()), + ) + .unwrap() +} + +#[test] +fn basics() { + let array = data(); + assert_eq!(array.size(), 2); + assert_eq!(array.len(), 2); + assert_eq!(array.validity(), Some(&Bitmap::from([true, false]))); + + assert_eq!(array.value(0).as_ref(), Int32Array::from_slice([10, 20])); + assert_eq!(array.value(1).as_ref(), Int32Array::from_slice([0, 0])); + + let array = array.sliced(1, 1); + + assert_eq!(array.value(0).as_ref(), Int32Array::from_slice([0, 0])); +} + +#[test] +fn with_validity() { + let array = data(); + + let a = array.with_validity(None); + assert!(a.validity().is_none()); +} + +#[test] +fn debug() { + let array = data(); + + assert_eq!(format!("{array:?}"), "FixedSizeListArray[[10, 20], None]"); +} + +#[test] +fn empty() { + let array = FixedSizeListArray::new_empty(ArrowDataType::FixedSizeList( + Box::new(Field::new("a", ArrowDataType::Int32, true)), + 2, + )); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn null() { + let array = FixedSizeListArray::new_null( + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + 2, + ); + assert_eq!(array.values().len(), 4); + assert_eq!(array.validity().cloned(), Some([false, false].into())); +} + +#[test] +fn wrong_size() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!(FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + values.boxed(), + None + ) + .is_err()); +} + +#[test] +fn wrong_len() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!(FixedSizeListArray::try_new( + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Int32, true)), 2), + values.boxed(), + Some([true, false, false].into()), // it should be 2 + ) + .is_err()); +} + +#[test] +fn wrong_data_type() { + let values = Int32Array::from_slice([10, 20, 0]); + assert!(FixedSizeListArray::try_new( + ArrowDataType::Binary, + values.boxed(), + Some([true, false, false].into()), // it should be 2 + ) + .is_err()); +} diff --git a/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs new file mode 100644 index 0000000000000..23ea53231059c --- /dev/null +++ b/crates/polars/tests/it/arrow/array/fixed_size_list/mutable.rs @@ -0,0 +1,88 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn primitive() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![None, None, None]), + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut list = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + list.try_extend(data).unwrap(); + let list: FixedSizeListArray = list.into(); + + let a = list.value(0); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![Some(1i32), Some(2), Some(3)]); + assert_eq!(a, &expected); + + let a = list.value(1); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![None, None, None]); + assert_eq!(a, &expected) +} + +#[test] +fn new_with_field() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![None, None, None]), + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut list = MutableFixedSizeListArray::new_with_field( + MutablePrimitiveArray::::new(), + "custom_items", + false, + 3, + ); + list.try_extend(data).unwrap(); + let list: FixedSizeListArray = list.into(); + + assert_eq!( + list.data_type(), + &ArrowDataType::FixedSizeList( + Box::new(Field::new("custom_items", ArrowDataType::Int32, false)), + 3 + ) + ); + + let a = list.value(0); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![Some(1i32), Some(2), Some(3)]); + assert_eq!(a, &expected); + + let a = list.value(1); + let a = a.as_any().downcast_ref::().unwrap(); + + let expected = Int32Array::from(vec![None, None, None]); + assert_eq!(a, &expected) +} + +#[test] +fn extend_from_self() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let mut a = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + a.try_extend(data.clone()).unwrap(); + + a.try_extend_from_self(&a.clone()).unwrap(); + let a: FixedSizeListArray = a.into(); + + let mut expected = data.clone(); + expected.extend(data); + + let mut b = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + b.try_extend(expected).unwrap(); + let b: FixedSizeListArray = b.into(); + + assert_eq!(a, b); +} diff --git a/crates/polars/tests/it/arrow/array/growable/binary.rs b/crates/polars/tests/it/arrow/array/growable/binary.rs new file mode 100644 index 0000000000000..20c0cd31081bb --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/binary.rs @@ -0,0 +1,97 @@ +use arrow::array::growable::{Growable, GrowableBinary}; +use arrow::array::BinaryArray; + +#[test] +fn no_offsets() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + + let mut a = GrowableBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("bc"), None]); + assert_eq!(result, expected); +} + +/// tests extending from a variable-sized (strings and binary) array +/// with an offset and nulls +#[test] +fn with_offsets() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn test_string_offsets() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn test_multiple_with_validity() { + let array1 = BinaryArray::::from_slice([b"hello", b"world"]); + let array2 = BinaryArray::::from([Some("1"), None]); + + let mut a = GrowableBinary::new(vec![&array1, &array2], false, 5); + + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 0, 2); + } + assert_eq!(a.len(), 4); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([Some("hello"), Some("world"), Some("1"), None]); + assert_eq!(result, expected); +} + +#[test] +fn test_string_null_offset_validity() { + let array = BinaryArray::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableBinary::new(vec![&array], true, 0); + + unsafe { + a.extend(0, 1, 2); + } + a.extend_validity(1); + assert_eq!(a.len(), 3); + + let result: BinaryArray = a.into(); + + let expected = BinaryArray::::from([None, Some("defh"), None]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/boolean.rs b/crates/polars/tests/it/arrow/array/growable/boolean.rs new file mode 100644 index 0000000000000..b6721029cb819 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/boolean.rs @@ -0,0 +1,19 @@ +use arrow::array::growable::{Growable, GrowableBoolean}; +use arrow::array::BooleanArray; + +#[test] +fn test_bool() { + let array = BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]); + + let mut a = GrowableBoolean::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + + let result: BooleanArray = a.into(); + + let expected = BooleanArray::from(vec![Some(true), None]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/dictionary.rs b/crates/polars/tests/it/arrow/array/growable/dictionary.rs new file mode 100644 index 0000000000000..e2a48275d7aed --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/dictionary.rs @@ -0,0 +1,72 @@ +use arrow::array::growable::{Growable, GrowableDictionary}; +use arrow::array::*; +use polars_error::PolarsResult; + +#[test] +fn test_single() -> PolarsResult<()> { + let original_data = vec![Some("a"), Some("b"), Some("a")]; + + let data = original_data.clone(); + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(data)?; + let array = array.into(); + + // same values, less keys + let expected = DictionaryArray::try_from_keys( + PrimitiveArray::from_vec(vec![1, 0]), + Box::new(Utf8Array::::from(&original_data)), + ) + .unwrap(); + + let mut growable = GrowableDictionary::new(&[&array], false, 0); + + unsafe { + growable.extend(0, 1, 2); + } + assert_eq!(growable.len(), 2); + + let result: DictionaryArray = growable.into(); + + assert_eq!(result, expected); + Ok(()) +} + +#[test] +fn test_multi() -> PolarsResult<()> { + let mut original_data1 = vec![Some("a"), Some("b"), None, Some("a")]; + let original_data2 = vec![Some("c"), Some("b"), None, Some("a")]; + + let data1 = original_data1.clone(); + let data2 = original_data2.clone(); + + let mut array1 = MutableDictionaryArray::>::new(); + array1.try_extend(data1)?; + let array1: DictionaryArray = array1.into(); + + let mut array2 = MutableDictionaryArray::>::new(); + array2.try_extend(data2)?; + let array2: DictionaryArray = array2.into(); + + // same values, less keys + original_data1.extend(original_data2.iter().cloned()); + let expected = DictionaryArray::try_from_keys( + PrimitiveArray::from(&[Some(1), None, Some(3), None]), + Utf8Array::::from_slice(["a", "b", "c", "b", "a"]).boxed(), + ) + .unwrap(); + + let mut growable = GrowableDictionary::new(&[&array1, &array2], false, 0); + + unsafe { + growable.extend(0, 1, 2); + } + unsafe { + growable.extend(1, 1, 2); + } + assert_eq!(growable.len(), 4); + + let result: DictionaryArray = growable.into(); + + assert_eq!(result, expected); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/array/growable/fixed_binary.rs b/crates/polars/tests/it/arrow/array/growable/fixed_binary.rs new file mode 100644 index 0000000000000..9ebb631f682c4 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/fixed_binary.rs @@ -0,0 +1,146 @@ +use arrow::array::growable::{Growable, GrowableFixedSizeBinary}; +use arrow::array::FixedSizeBinaryArray; + +/// tests extending from a variable-sized (strings and binary) array w/ offset with nulls +#[test] +fn basic() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"de")], 2); + + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some("bc"), None], 2); + assert_eq!(result, expected); +} + +/// tests extending from a variable-sized (strings and binary) array +/// with an offset and nulls +#[test] +fn offsets() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"fh")], 2); + let array = array.sliced(1, 3); + + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some(b"bc"), None, Some(b"fh")], 2); + assert_eq!(result, expected); +} + +#[test] +fn multiple_with_validity() { + let array1 = FixedSizeBinaryArray::from_iter(vec![Some("hello"), Some("world")], 5); + let array2 = FixedSizeBinaryArray::from_iter(vec![Some("12345"), None], 5); + + let mut a = GrowableFixedSizeBinary::new(vec![&array1, &array2], false, 5); + + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 0, 2); + } + assert_eq!(a.len(), 4); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = + FixedSizeBinaryArray::from_iter(vec![Some("hello"), Some("world"), Some("12345"), None], 5); + assert_eq!(result, expected); +} + +#[test] +fn null_offset_validity() { + let array = FixedSizeBinaryArray::from_iter(vec![Some("aa"), Some("bc"), None, Some("fh")], 2); + let array = array.sliced(1, 3); + + let mut a = GrowableFixedSizeBinary::new(vec![&array], true, 0); + + unsafe { + a.extend(0, 1, 2); + } + a.extend_validity(1); + assert_eq!(a.len(), 3); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![None, Some("fh"), None], 2); + assert_eq!(result, expected); +} + +#[test] +fn sized_offsets() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(&[0, 0]), Some(&[0, 1]), Some(&[0, 2])], 2); + let array = array.sliced(1, 2); + // = [[0, 1], [0, 2]] due to the offset = 1 + + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 1); + } + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 2); + + let result: FixedSizeBinaryArray = a.into(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some(&[0, 2]), Some(&[0, 1])], 2); + assert_eq!(result, expected); +} + +/// to, as_box, as_arc +#[test] +fn as_box() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"de")], 2); + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 2); + } + + let result = a.as_box(); + let result = result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some("bc"), None], 2); + assert_eq!(&expected, result); +} + +/// as_arc +#[test] +fn as_arc() { + let array = + FixedSizeBinaryArray::from_iter(vec![Some(b"ab"), Some(b"bc"), None, Some(b"de")], 2); + let mut a = GrowableFixedSizeBinary::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 2); + } + + let result = a.as_arc(); + let result = result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = FixedSizeBinaryArray::from_iter(vec![Some("bc"), None], 2); + assert_eq!(&expected, result); +} diff --git a/crates/polars/tests/it/arrow/array/growable/fixed_size_list.rs b/crates/polars/tests/it/arrow/array/growable/fixed_size_list.rs new file mode 100644 index 0000000000000..dcdc25d1bda9a --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/fixed_size_list.rs @@ -0,0 +1,95 @@ +use arrow::array::growable::{Growable, GrowableFixedSizeList}; +use arrow::array::{ + FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryExtend, +}; + +fn create_list_array(data: Vec>>>) -> FixedSizeListArray { + let mut array = MutableFixedSizeListArray::new(MutablePrimitiveArray::::new(), 3); + array.try_extend(data).unwrap(); + array.into() +} + +#[test] +fn basic() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + Some(vec![Some(7i32), Some(8), Some(9)]), + ]; + + let array = create_list_array(data); + + let mut a = GrowableFixedSizeList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 1); + + let result: FixedSizeListArray = a.into(); + + let expected = vec![Some(vec![Some(1i32), Some(2), Some(3)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn null_offset() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + let array = create_list_array(data); + let array = array.sliced(1, 2); + + let mut a = GrowableFixedSizeList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 1); + + let result: FixedSizeListArray = a.into(); + + let expected = vec![Some(vec![Some(6i32), Some(7), Some(8)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn test_from_two_lists() { + let data_1 = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), None, Some(8)]), + ]; + let array_1 = create_list_array(data_1); + + let data_2 = vec![ + Some(vec![Some(8i32), Some(7), Some(6)]), + Some(vec![Some(5i32), None, Some(4)]), + Some(vec![Some(2i32), Some(1), Some(0)]), + ]; + let array_2 = create_list_array(data_2); + + let mut a = GrowableFixedSizeList::new(vec![&array_1, &array_2], false, 6); + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 1, 1); + } + assert_eq!(a.len(), 3); + + let result: FixedSizeListArray = a.into(); + + let expected = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(5i32), None, Some(4)]), + ]; + let expected = create_list_array(expected); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/list.rs b/crates/polars/tests/it/arrow/array/growable/list.rs new file mode 100644 index 0000000000000..1bc0985ceb4f7 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/list.rs @@ -0,0 +1,147 @@ +use arrow::array::growable::{Growable, GrowableList}; +use arrow::array::{Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend}; +use arrow::datatypes::ArrowDataType; + +fn create_list_array(data: Vec>>>) -> ListArray { + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + array.into() +} + +#[test] +fn extension() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + + let array = create_list_array(data); + + let data_type = + ArrowDataType::Extension("ext".to_owned(), Box::new(array.data_type().clone()), None); + let array_ext = ListArray::new( + data_type, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + + let mut a = GrowableList::new(vec![&array_ext], false, 0); + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + assert_eq!(array_ext.data_type(), result.data_type()); +} + +#[test] +fn basic() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + + let array = create_list_array(data); + + let mut a = GrowableList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 0, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + + let expected = vec![Some(vec![Some(1i32), Some(2), Some(3)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn null_offset() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), Some(7), Some(8)]), + ]; + let array = create_list_array(data); + let array = array.sliced(1, 2); + + let mut a = GrowableList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + + let expected = vec![Some(vec![Some(6i32), Some(7), Some(8)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn null_offsets() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), None, Some(8)]), + ]; + let array = create_list_array(data); + let array = array.sliced(1, 2); + + let mut a = GrowableList::new(vec![&array], false, 0); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 1); + + let result: ListArray = a.into(); + + let expected = vec![Some(vec![Some(6i32), None, Some(8)])]; + let expected = create_list_array(expected); + + assert_eq!(result, expected) +} + +#[test] +fn test_from_two_lists() { + let data_1 = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(6i32), None, Some(8)]), + ]; + let array_1 = create_list_array(data_1); + + let data_2 = vec![ + Some(vec![Some(8i32), Some(7), Some(6)]), + Some(vec![Some(5i32), None, Some(4)]), + Some(vec![Some(2i32), Some(1), Some(0)]), + ]; + let array_2 = create_list_array(data_2); + + let mut a = GrowableList::new(vec![&array_1, &array_2], false, 6); + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 1, 1); + } + assert_eq!(a.len(), 3); + + let result: ListArray = a.into(); + + let expected = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(5i32), None, Some(4)]), + ]; + let expected = create_list_array(expected); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/mod.rs b/crates/polars/tests/it/arrow/array/growable/mod.rs new file mode 100644 index 0000000000000..43496a1e95b1f --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/mod.rs @@ -0,0 +1,75 @@ +mod binary; +mod boolean; +mod dictionary; +mod fixed_binary; +mod fixed_size_list; +mod list; +mod null; +mod primitive; +mod struct_; +mod utf8; + +use arrow::array::growable::make_growable; +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn test_make_growable() { + let array = Int32Array::from_slice([1, 2]); + make_growable(&[&array], false, 2); + + let array = Utf8Array::::from_slice(["a", "aa"]); + make_growable(&[&array], false, 2); + + let array = Utf8Array::::from_slice(["a", "aa"]); + make_growable(&[&array], false, 2); + + let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); + make_growable(&[&array], false, 2); + + let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); + make_growable(&[&array], false, 2); + + let array = BinaryArray::::from_slice([b"a".as_ref(), b"aa".as_ref()]); + make_growable(&[&array], false, 2); + + let array = FixedSizeBinaryArray::new( + ArrowDataType::FixedSizeBinary(2), + b"abcd".to_vec().into(), + None, + ); + make_growable(&[&array], false, 2); +} + +#[test] +fn test_make_growable_extension() { + let array = DictionaryArray::try_from_keys( + Int32Array::from_slice([1, 0]), + Int32Array::from_slice([1, 2]).boxed(), + ) + .unwrap(); + make_growable(&[&array], false, 2); + + let data_type = + ArrowDataType::Extension("ext".to_owned(), Box::new(ArrowDataType::Int32), None); + let array = Int32Array::from_slice([1, 2]).to(data_type.clone()); + let array_grown = make_growable(&[&array], false, 2).as_box(); + assert_eq!(array_grown.data_type(), &data_type); + + let data_type = ArrowDataType::Extension( + "ext".to_owned(), + Box::new(ArrowDataType::Struct(vec![Field::new( + "a", + ArrowDataType::Int32, + false, + )])), + None, + ); + let array = StructArray::new( + data_type.clone(), + vec![Int32Array::from_slice([1, 2]).boxed()], + None, + ); + let array_grown = make_growable(&[&array], false, 2).as_box(); + assert_eq!(array_grown.data_type(), &data_type); +} diff --git a/crates/polars/tests/it/arrow/array/growable/null.rs b/crates/polars/tests/it/arrow/array/growable/null.rs new file mode 100644 index 0000000000000..2d6a118a117cc --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/null.rs @@ -0,0 +1,21 @@ +use arrow::array::growable::{Growable, GrowableNull}; +use arrow::array::NullArray; +use arrow::datatypes::ArrowDataType; + +#[test] +fn null() { + let mut mutable = GrowableNull::default(); + + unsafe { + mutable.extend(0, 1, 2); + } + unsafe { + mutable.extend(1, 0, 1); + } + assert_eq!(mutable.len(), 3); + + let result: NullArray = mutable.into(); + + let expected = NullArray::new(ArrowDataType::Null, 3); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/primitive.rs b/crates/polars/tests/it/arrow/array/growable/primitive.rs new file mode 100644 index 0000000000000..37c105f2c7282 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/primitive.rs @@ -0,0 +1,82 @@ +use arrow::array::growable::{Growable, GrowablePrimitive}; +use arrow::array::PrimitiveArray; + +/// tests extending from a primitive array w/ offset nor nulls +#[test] +fn basics() { + let b = PrimitiveArray::::from(vec![Some(1), Some(2), Some(3)]); + let mut a = GrowablePrimitive::new(vec![&b], false, 3); + unsafe { + a.extend(0, 0, 2); + } + assert_eq!(a.len(), 2); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(vec![Some(1), Some(2)]); + assert_eq!(result, expected); +} + +/// tests extending from a primitive array with offset w/ nulls +#[test] +fn offset() { + let b = PrimitiveArray::::from(vec![Some(1), Some(2), Some(3)]); + let b = b.sliced(1, 2); + let mut a = GrowablePrimitive::new(vec![&b], false, 2); + unsafe { + a.extend(0, 0, 2); + } + assert_eq!(a.len(), 2); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(vec![Some(2), Some(3)]); + assert_eq!(result, expected); +} + +/// tests extending from a primitive array with offset and nulls +#[test] +fn null_offset() { + let b = PrimitiveArray::::from(vec![Some(1), None, Some(3)]); + let b = b.sliced(1, 2); + let mut a = GrowablePrimitive::new(vec![&b], false, 2); + unsafe { + a.extend(0, 0, 2); + } + assert_eq!(a.len(), 2); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(vec![None, Some(3)]); + assert_eq!(result, expected); +} + +#[test] +fn null_offset_validity() { + let b = PrimitiveArray::::from(&[Some(1), Some(2), Some(3)]); + let b = b.sliced(1, 2); + let mut a = GrowablePrimitive::new(vec![&b], true, 2); + unsafe { + a.extend(0, 0, 2); + } + a.extend_validity(3); + unsafe { + a.extend(0, 1, 1); + } + assert_eq!(a.len(), 6); + let result: PrimitiveArray = a.into(); + let expected = PrimitiveArray::::from(&[Some(2), Some(3), None, None, None, Some(3)]); + assert_eq!(result, expected); +} + +#[test] +fn joining_arrays() { + let b = PrimitiveArray::::from(&[Some(1), Some(2), Some(3)]); + let c = PrimitiveArray::::from(&[Some(4), Some(5), Some(6)]); + let mut a = GrowablePrimitive::new(vec![&b, &c], false, 4); + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 1, 2); + } + assert_eq!(a.len(), 4); + let result: PrimitiveArray = a.into(); + + let expected = PrimitiveArray::::from(&[Some(1), Some(2), Some(5), Some(6)]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/growable/struct_.rs b/crates/polars/tests/it/arrow/array/growable/struct_.rs new file mode 100644 index 0000000000000..809e70749f091 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/struct_.rs @@ -0,0 +1,139 @@ +use arrow::array::growable::{Growable, GrowableStruct}; +use arrow::array::{Array, PrimitiveArray, StructArray, Utf8Array}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field}; + +fn some_values() -> (ArrowDataType, Vec>) { + let strings: Box = Box::new(Utf8Array::::from([ + Some("a"), + Some("aa"), + None, + Some("mark"), + Some("doe"), + ])); + let ints: Box = Box::new(PrimitiveArray::::from(&[ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])); + let fields = vec![ + Field::new("f1", ArrowDataType::Utf8, true), + Field::new("f2", ArrowDataType::Int32, true), + ]; + (ArrowDataType::Struct(fields), vec![strings, ints]) +} + +#[test] +fn basic() { + let (fields, values) = some_values(); + + let array = StructArray::new(fields.clone(), values.clone(), None); + + let mut a = GrowableStruct::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + let result: StructArray = a.into(); + + let expected = StructArray::new( + fields, + vec![values[0].sliced(1, 2), values[1].sliced(1, 2)], + None, + ); + assert_eq!(result, expected) +} + +#[test] +fn offset() { + let (fields, values) = some_values(); + + let array = StructArray::new(fields.clone(), values.clone(), None).sliced(1, 3); + + let mut a = GrowableStruct::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + let result: StructArray = a.into(); + + let expected = StructArray::new( + fields, + vec![values[0].sliced(2, 2), values[1].sliced(2, 2)], + None, + ); + + assert_eq!(result, expected); +} + +#[test] +fn nulls() { + let (fields, values) = some_values(); + + let array = StructArray::new( + fields.clone(), + values.clone(), + Some(Bitmap::from_u8_slice([0b00000010], 5)), + ); + + let mut a = GrowableStruct::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + assert_eq!(a.len(), 2); + let result: StructArray = a.into(); + + let expected = StructArray::new( + fields, + vec![values[0].sliced(1, 2), values[1].sliced(1, 2)], + Some(Bitmap::from_u8_slice([0b00000010], 5).sliced(1, 2)), + ); + + assert_eq!(result, expected) +} + +#[test] +fn many() { + let (fields, values) = some_values(); + + let array = StructArray::new(fields.clone(), values.clone(), None); + + let mut mutable = GrowableStruct::new(vec![&array, &array], true, 0); + + unsafe { + mutable.extend(0, 1, 2); + } + unsafe { + mutable.extend(1, 0, 2); + } + mutable.extend_validity(1); + assert_eq!(mutable.len(), 5); + let result = mutable.as_box(); + + let expected_string: Box = Box::new(Utf8Array::::from([ + Some("aa"), + None, + Some("a"), + Some("aa"), + None, + ])); + let expected_int: Box = Box::new(PrimitiveArray::::from(vec![ + Some(2), + Some(3), + Some(1), + Some(2), + None, + ])); + + let expected = StructArray::new( + fields, + vec![expected_string, expected_int], + Some(Bitmap::from([true, true, true, true, false])), + ); + assert_eq!(expected, result.as_ref()) +} diff --git a/crates/polars/tests/it/arrow/array/growable/utf8.rs b/crates/polars/tests/it/arrow/array/growable/utf8.rs new file mode 100644 index 0000000000000..af2be2ab98674 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/growable/utf8.rs @@ -0,0 +1,97 @@ +use arrow::array::growable::{Growable, GrowableUtf8}; +use arrow::array::Utf8Array; + +/// tests extending from a variable-sized (strings and binary) array w/ offset with nulls +#[test] +fn validity() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + + let mut a = GrowableUtf8::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 1, 2); + } + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("bc"), None]); + assert_eq!(result, expected); +} + +/// tests extending from a variable-sized (strings and binary) array +/// with an offset and nulls +#[test] +fn offsets() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableUtf8::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn offsets2() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableUtf8::new(vec![&array], false, 0); + + unsafe { + a.extend(0, 0, 3); + } + assert_eq!(a.len(), 3); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("bc"), None, Some("defh")]); + assert_eq!(result, expected); +} + +#[test] +fn multiple_with_validity() { + let array1 = Utf8Array::::from_slice(["hello", "world"]); + let array2 = Utf8Array::::from([Some("1"), None]); + + let mut a = GrowableUtf8::new(vec![&array1, &array2], false, 5); + + unsafe { + a.extend(0, 0, 2); + } + unsafe { + a.extend(1, 0, 2); + } + assert_eq!(a.len(), 4); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([Some("hello"), Some("world"), Some("1"), None]); + assert_eq!(result, expected); +} + +#[test] +fn null_offset_validity() { + let array = Utf8Array::::from([Some("a"), Some("bc"), None, Some("defh")]); + let array = array.sliced(1, 3); + + let mut a = GrowableUtf8::new(vec![&array], true, 0); + + unsafe { + a.extend(0, 1, 2); + } + a.extend_validity(1); + assert_eq!(a.len(), 3); + + let result: Utf8Array = a.into(); + + let expected = Utf8Array::::from([None, Some("defh"), None]); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/array/list/mod.rs b/crates/polars/tests/it/arrow/array/list/mod.rs new file mode 100644 index 0000000000000..77e443781c177 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/list/mod.rs @@ -0,0 +1,70 @@ +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +mod mutable; + +#[test] +fn debug() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + data_type, + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + assert_eq!(format!("{array:?}"), "ListArray[[1, 2], [], [3], [4, 5]]"); +} + +#[test] +#[should_panic] +fn test_nested_panic() { + let values = Buffer::from(vec![1, 2, 3, 4, 5]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + data_type.clone(), + vec![0, 2, 2, 3, 5].try_into().unwrap(), + Box::new(values), + None, + ); + + // The datatype for the nested array has to be created considering + // the nested structure of the child data + let _ = ListArray::::new( + data_type, + vec![0, 2, 4].try_into().unwrap(), + Box::new(array), + None, + ); +} + +#[test] +fn test_nested_display() { + let values = Buffer::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + let values = PrimitiveArray::::new(ArrowDataType::Int32, values, None); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let array = ListArray::::new( + data_type, + vec![0, 2, 4, 7, 7, 8, 10].try_into().unwrap(), + Box::new(values), + None, + ); + + let data_type = ListArray::::default_datatype(array.data_type().clone()); + let nested = ListArray::::new( + data_type, + vec![0, 2, 5, 6].try_into().unwrap(), + Box::new(array), + None, + ); + + let expected = "ListArray[[[1, 2], [3, 4]], [[5, 6, 7], [], [8]], [[9, 10]]]"; + assert_eq!(format!("{nested:?}"), expected); +} diff --git a/crates/polars/tests/it/arrow/array/list/mutable.rs b/crates/polars/tests/it/arrow/array/list/mutable.rs new file mode 100644 index 0000000000000..2d4ba0c4d2f11 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/list/mutable.rs @@ -0,0 +1,76 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; + +#[test] +fn basics() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + let array: ListArray = array.into(); + + let values = PrimitiveArray::::new( + ArrowDataType::Int32, + Buffer::from(vec![1, 2, 3, 4, 0, 6]), + Some(Bitmap::from([true, true, true, true, false, true])), + ); + + let data_type = ListArray::::default_datatype(ArrowDataType::Int32); + let expected = ListArray::::new( + data_type, + vec![0, 3, 3, 6].try_into().unwrap(), + Box::new(values), + Some(Bitmap::from([true, false, true])), + ); + assert_eq!(expected, array); +} + +#[test] +fn with_capacity() { + let array = MutableListArray::>::with_capacity(10); + assert!(array.offsets().capacity() >= 10); + assert_eq!(array.offsets().len_proxy(), 0); + assert_eq!(array.values().values().capacity(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn push() { + let mut array = MutableListArray::>::new(); + array + .try_push(Some(vec![Some(1i32), Some(2), Some(3)])) + .unwrap(); + assert_eq!(array.len(), 1); + assert_eq!(array.values().values().as_ref(), [1, 2, 3]); + assert_eq!(array.offsets().as_slice(), [0, 3]); + assert_eq!(array.validity(), None); +} + +#[test] +fn extend_from_self() { + let data = vec![ + Some(vec![Some(1i32), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let mut a = MutableListArray::>::new(); + a.try_extend(data.clone()).unwrap(); + + a.try_extend_from_self(&a.clone()).unwrap(); + let a: ListArray = a.into(); + + let mut expected = data.clone(); + expected.extend(data); + + let mut b = MutableListArray::>::new(); + b.try_extend(expected).unwrap(); + let b: ListArray = b.into(); + + assert_eq!(a, b); +} diff --git a/crates/polars/tests/it/arrow/array/map/mod.rs b/crates/polars/tests/it/arrow/array/map/mod.rs new file mode 100644 index 0000000000000..30d1d263a9d7d --- /dev/null +++ b/crates/polars/tests/it/arrow/array/map/mod.rs @@ -0,0 +1,52 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn basics() { + let dt = ArrowDataType::Struct(vec![ + Field::new("a", ArrowDataType::Utf8, true), + Field::new("b", ArrowDataType::Utf8, true), + ]); + let data_type = ArrowDataType::Map(Box::new(Field::new("a", dt.clone(), true)), false); + + let field = StructArray::new( + dt.clone(), + vec![ + Box::new(Utf8Array::::from_slice(["a", "aa", "aaa"])) as _, + Box::new(Utf8Array::::from_slice(["b", "bb", "bbb"])), + ], + None, + ); + + let array = MapArray::new( + data_type, + vec![0, 1, 2].try_into().unwrap(), + Box::new(field), + None, + ); + + assert_eq!( + array.value(0), + Box::new(StructArray::new( + dt.clone(), + vec![ + Box::new(Utf8Array::::from_slice(["a"])) as _, + Box::new(Utf8Array::::from_slice(["b"])), + ], + None, + )) as Box + ); + + let sliced = array.sliced(1, 1); + assert_eq!( + sliced.value(0), + Box::new(StructArray::new( + dt, + vec![ + Box::new(Utf8Array::::from_slice(["aa"])) as _, + Box::new(Utf8Array::::from_slice(["bb"])), + ], + None, + )) as Box + ); +} diff --git a/crates/polars/tests/it/arrow/array/mod.rs b/crates/polars/tests/it/arrow/array/mod.rs new file mode 100644 index 0000000000000..89fbe3f19ad5a --- /dev/null +++ b/crates/polars/tests/it/arrow/array/mod.rs @@ -0,0 +1,141 @@ +mod binary; +mod boolean; +mod dictionary; +mod equal; +mod fixed_size_binary; +mod fixed_size_list; +mod growable; +mod list; +mod map; +mod primitive; +mod struct_; +mod union; +mod utf8; + +use arrow::array::{clone, new_empty_array, new_null_array, Array, PrimitiveArray}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::{ArrowDataType, Field, UnionMode}; + +#[test] +fn nulls() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ]; + let a = datatypes + .into_iter() + .all(|x| new_null_array(x, 10).null_count() == 10); + assert!(a); + + // unions' null count is always 0 + let datatypes = vec![ + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ]; + let a = datatypes + .into_iter() + .all(|x| new_null_array(x, 10).null_count() == 0); + assert!(a); +} + +#[test] +fn empty() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::List(Box::new(Field::new( + "a", + ArrowDataType::Extension("ext".to_owned(), Box::new(ArrowDataType::Int32), None), + true, + ))), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Int32, true)]), + ]; + let a = datatypes.into_iter().all(|x| new_empty_array(x).len() == 0); + assert!(a); +} + +#[test] +fn empty_extension() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Sparse, + ), + ArrowDataType::Union( + vec![Field::new("a", ArrowDataType::Binary, true)], + None, + UnionMode::Dense, + ), + ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Int32, true)]), + ]; + let a = datatypes + .into_iter() + .map(|dt| ArrowDataType::Extension("ext".to_owned(), Box::new(dt), None)) + .all(|x| { + let a = new_empty_array(x); + a.len() == 0 && matches!(a.data_type(), ArrowDataType::Extension(_, _, _)) + }); + assert!(a); +} + +#[test] +fn test_clone() { + let datatypes = vec![ + ArrowDataType::Int32, + ArrowDataType::Float64, + ArrowDataType::Utf8, + ArrowDataType::Binary, + ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Binary, true))), + ]; + let a = datatypes + .into_iter() + .all(|x| clone(new_null_array(x.clone(), 10).as_ref()) == new_null_array(x, 10)); + assert!(a); +} + +#[test] +fn test_with_validity() { + let arr = PrimitiveArray::from_slice([1i32, 2, 3]); + let validity = Bitmap::from(&[true, false, true]); + let arr = arr.with_validity(Some(validity)); + let arr_ref = arr.as_any().downcast_ref::>().unwrap(); + + let expected = PrimitiveArray::from(&[Some(1i32), None, Some(3)]); + assert_eq!(arr_ref, &expected); +} + +// check that we ca derive stuff +#[derive(PartialEq, Clone, Debug)] +struct A { + array: Box, +} diff --git a/crates/polars/tests/it/arrow/array/primitive/fmt.rs b/crates/polars/tests/it/arrow/array/primitive/fmt.rs new file mode 100644 index 0000000000000..6ab0ffa1ee8bb --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/fmt.rs @@ -0,0 +1,224 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::types::{days_ms, months_days_ns}; + +#[test] +fn debug_int32() { + let array = Int32Array::from(&[Some(1), None, Some(2)]); + assert_eq!(format!("{array:?}"), "Int32[1, None, 2]"); +} + +#[test] +fn debug_date32() { + let array = Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Date32); + assert_eq!(format!("{array:?}"), "Date32[1970-01-02, None, 1970-01-03]"); +} + +#[test] +fn debug_time32s() { + let array = + Int32Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Time32(TimeUnit::Second)); + assert_eq!( + format!("{array:?}"), + "Time32(Second)[00:00:01, None, 00:00:02]" + ); +} + +#[test] +fn debug_time32ms() { + let array = Int32Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Time32(TimeUnit::Millisecond)); + assert_eq!( + format!("{array:?}"), + "Time32(Millisecond)[00:00:00.001, None, 00:00:00.002]" + ); +} + +#[test] +fn debug_interval_d() { + let array = Int32Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Interval(IntervalUnit::YearMonth)); + assert_eq!(format!("{array:?}"), "Interval(YearMonth)[1m, None, 2m]"); +} + +#[test] +fn debug_int64() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Int64); + assert_eq!(format!("{array:?}"), "Int64[1, None, 2]"); +} + +#[test] +fn debug_date64() { + let array = Int64Array::from(&[Some(1), None, Some(86400000)]).to(ArrowDataType::Date64); + assert_eq!(format!("{array:?}"), "Date64[1970-01-01, None, 1970-01-02]"); +} + +#[test] +fn debug_time64us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Time64(TimeUnit::Microsecond)); + assert_eq!( + format!("{array:?}"), + "Time64(Microsecond)[00:00:00.000001, None, 00:00:00.000002]" + ); +} + +#[test] +fn debug_time64ns() { + let array = + Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Time64(TimeUnit::Nanosecond)); + assert_eq!( + format!("{array:?}"), + "Time64(Nanosecond)[00:00:00.000000001, None, 00:00:00.000000002]" + ); +} + +#[test] +fn debug_timestamp_s() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Second, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Second, None)[1970-01-01 00:00:01, None, 1970-01-01 00:00:02]" + ); +} + +#[test] +fn debug_timestamp_ms() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Millisecond, None)[1970-01-01 00:00:00.001, None, 1970-01-01 00:00:00.002]" + ); +} + +#[test] +fn debug_timestamp_us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Microsecond, None)[1970-01-01 00:00:00.000001, None, 1970-01-01 00:00:00.000002]" + ); +} + +#[test] +fn debug_timestamp_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, None)[1970-01-01 00:00:00.000000001, None, 1970-01-01 00:00:00.000000002]" + ); +} + +#[test] +fn debug_timestamp_tz_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("+02:00".to_string()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"+02:00\"))[1970-01-01 02:00:00.000000001 +02:00, None, 1970-01-01 02:00:00.000000002 +02:00]" + ); +} + +#[test] +fn debug_timestamp_tz_not_parsable() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("aa".to_string()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"aa\"))[1 (aa), None, 2 (aa)]" + ); +} + +#[cfg(feature = "chrono-tz")] +#[test] +fn debug_timestamp_tz1_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Timestamp( + TimeUnit::Nanosecond, + Some("Europe/Lisbon".to_string()), + )); + assert_eq!( + format!("{array:?}"), + "Timestamp(Nanosecond, Some(\"Europe/Lisbon\"))[1970-01-01 01:00:00.000000001 CET, None, 1970-01-01 01:00:00.000000002 CET]" + ); +} + +#[test] +fn debug_duration_ms() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Millisecond)); + assert_eq!( + format!("{array:?}"), + "Duration(Millisecond)[1ms, None, 2ms]" + ); +} + +#[test] +fn debug_duration_s() { + let array = + Int64Array::from(&[Some(1), None, Some(2)]).to(ArrowDataType::Duration(TimeUnit::Second)); + assert_eq!(format!("{array:?}"), "Duration(Second)[1s, None, 2s]"); +} + +#[test] +fn debug_duration_us() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Microsecond)); + assert_eq!( + format!("{array:?}"), + "Duration(Microsecond)[1us, None, 2us]" + ); +} + +#[test] +fn debug_duration_ns() { + let array = Int64Array::from(&[Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(format!("{array:?}"), "Duration(Nanosecond)[1ns, None, 2ns]"); +} + +#[test] +fn debug_decimal() { + let array = + Int128Array::from(&[Some(12345), None, Some(23456)]).to(ArrowDataType::Decimal(5, 2)); + assert_eq!(format!("{array:?}"), "Decimal(5, 2)[123.45, None, 234.56]"); +} + +#[test] +fn debug_decimal1() { + let array = + Int128Array::from(&[Some(12345), None, Some(23456)]).to(ArrowDataType::Decimal(5, 1)); + assert_eq!(format!("{array:?}"), "Decimal(5, 1)[1234.5, None, 2345.6]"); +} + +#[test] +fn debug_interval_days_ms() { + let array = DaysMsArray::from(&[Some(days_ms::new(1, 1)), None, Some(days_ms::new(2, 2))]); + assert_eq!( + format!("{array:?}"), + "Interval(DayTime)[1d1ms, None, 2d2ms]" + ); +} + +#[test] +fn debug_months_days_ns() { + let data = &[ + Some(months_days_ns::new(1, 1, 2)), + None, + Some(months_days_ns::new(2, 3, 3)), + ]; + + let array = MonthsDaysNsArray::from(&data); + + assert_eq!( + format!("{array:?}"), + "Interval(MonthDayNano)[1m1d2ns, None, 2m3d3ns]" + ); +} diff --git a/crates/polars/tests/it/arrow/array/primitive/mod.rs b/crates/polars/tests/it/arrow/array/primitive/mod.rs new file mode 100644 index 0000000000000..e36b68f5a6a74 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/mod.rs @@ -0,0 +1,140 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::*; +use arrow::types::months_days_ns; + +mod fmt; +mod mutable; +mod to_mutable; + +#[test] +fn basics() { + let data = vec![Some(1), None, Some(10)]; + + let array = Int32Array::from_iter(data); + + assert_eq!(array.value(0), 1); + assert_eq!(array.value(1), 0); + assert_eq!(array.value(2), 10); + assert_eq!(array.values().as_slice(), &[1, 0, 10]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = Int32Array::new( + ArrowDataType::Int32, + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), 0); + assert_eq!(array.value(1), 10); + assert_eq!(array.values().as_slice(), &[0, 10]); + + unsafe { + assert_eq!(array.value_unchecked(0), 0); + assert_eq!(array.value_unchecked(1), 10); + } +} + +#[test] +fn empty() { + let array = Int32Array::new_empty(ArrowDataType::Int32); + assert_eq!(array.values().len(), 0); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let data = vec![Some(1), None, Some(10)]; + + let array = PrimitiveArray::from(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_iter(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_trusted_len_iter(data.into_iter()); + assert_eq!(array.len(), 3); + + let data = vec![1i32, 2, 3]; + + let array = PrimitiveArray::from_values(data.clone()); + assert_eq!(array.len(), 3); + + let array = PrimitiveArray::from_trusted_len_values_iter(data.into_iter()); + assert_eq!(array.len(), 3); +} + +#[test] +fn months_days_ns_from_slice() { + let data = &[ + months_days_ns::new(1, 1, 2), + months_days_ns::new(1, 1, 3), + months_days_ns::new(2, 3, 3), + ]; + + let array = MonthsDaysNsArray::from_slice(data); + + let a = array.values().as_slice(); + assert_eq!(a, data.as_ref()); +} + +#[test] +fn wrong_data_type() { + let values = Buffer::from(b"abbb".to_vec()); + assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, None).is_err()); +} + +#[test] +fn wrong_len() { + let values = Buffer::from(b"abbb".to_vec()); + let validity = Some([true, false].into()); + assert!(PrimitiveArray::try_new(ArrowDataType::Utf8, values, validity).is_err()); +} + +#[test] +fn into_mut_1() { + let values = Buffer::::from(vec![0, 1]); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let values = Buffer::::from(vec![0, 1]); + let validity = Some([true, false].into()); + let a = validity.clone(); // cloned values + assert_eq!(a, validity); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let values = Buffer::::from(vec![0, 1]); + let validity = Some([true, false].into()); + let array = PrimitiveArray::new(ArrowDataType::Int32, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn into_iter() { + let data = vec![Some(1), None, Some(10)]; + let rev = data.clone().into_iter().rev(); + + let array: Int32Array = data.clone().into_iter().collect(); + + assert_eq!(array.clone().into_iter().collect::>(), data); + + assert!(array.into_iter().rev().eq(rev)) +} diff --git a/crates/polars/tests/it/arrow/array/primitive/mutable.rs b/crates/polars/tests/it/arrow/array/primitive/mutable.rs new file mode 100644 index 0000000000000..bd4d3831dc824 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/mutable.rs @@ -0,0 +1,328 @@ +use arrow::array::*; +use arrow::bitmap::{Bitmap, MutableBitmap}; +use arrow::datatypes::ArrowDataType; +use polars_error::PolarsResult; + +#[test] +fn from_and_into_data() { + let a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + assert_eq!(a.len(), 2); + let (a, b, c) = a.into_inner(); + assert_eq!(a, ArrowDataType::Int32); + assert_eq!(b, Vec::from([1i32, 0])); + assert_eq!(c, Some(MutableBitmap::from([true, false]))); +} + +#[test] +fn from_vec() { + let a = MutablePrimitiveArray::from_vec(Vec::from([1i32, 0])); + assert_eq!(a.len(), 2); +} + +#[test] +fn to() { + let a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.data_type(), &ArrowDataType::Date32); +} + +#[test] +fn values_mut_slice() { + let mut a = MutablePrimitiveArray::try_new( + ArrowDataType::Int32, + vec![1i32, 0], + Some(MutableBitmap::from([true, false])), + ) + .unwrap(); + let values = a.values_mut_slice(); + + values[0] = 10; + assert_eq!(a.values()[0], 10); +} + +#[test] +fn push() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.push(None); + a.push_null(); + assert_eq!(a.len(), 3); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + assert!(!a.is_valid(2)); + + assert_eq!(a.values(), &Vec::from([1, 0, 0])); +} + +#[test] +fn pop() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.push(None); + a.push(Some(2)); + a.push_null(); + assert_eq!(a.pop(), None); + assert_eq!(a.pop(), Some(2)); + assert_eq!(a.pop(), None); + assert!(a.is_valid(0)); + assert_eq!(a.values(), &Vec::from([1])); + assert_eq!(a.pop(), Some(1)); + assert_eq!(a.len(), 0); + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 0); +} + +#[test] +fn pop_all_some() { + let mut a = MutablePrimitiveArray::::new(); + for v in 0..8 { + a.push(Some(v)); + } + + a.push(Some(8)); + assert_eq!(a.pop(), Some(8)); + assert_eq!(a.pop(), Some(7)); + assert_eq!(a.pop(), Some(6)); + assert_eq!(a.pop(), Some(5)); + assert_eq!(a.pop(), Some(4)); + assert_eq!(a.len(), 4); + assert!(a.is_valid(0)); + assert!(a.is_valid(1)); + assert!(a.is_valid(2)); + assert!(a.is_valid(3)); + assert_eq!(a.values(), &Vec::from([0, 1, 2, 3])); +} + +#[test] +fn set() { + let mut a = MutablePrimitiveArray::::from([Some(1), None]); + + a.set(0, Some(2)); + a.set(1, Some(1)); + + assert_eq!(a.len(), 2); + assert!(a.is_valid(0)); + assert!(a.is_valid(1)); + + assert_eq!(a.values(), &Vec::from([2, 1])); + + let mut a = MutablePrimitiveArray::::from_slice([1, 2]); + + a.set(0, Some(2)); + a.set(1, None); + + assert_eq!(a.len(), 2); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + + assert_eq!(a.values(), &Vec::from([2, 0])); +} + +#[test] +fn from_iter() { + let a = MutablePrimitiveArray::::from_iter((0..2).map(Some)); + assert_eq!(a.len(), 2); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); +} + +#[test] +fn natural_arc() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).into_arc(); + assert_eq!(a.len(), 2); +} + +#[test] +fn as_arc() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).as_arc(); + assert_eq!(a.len(), 2); +} + +#[test] +fn as_box() { + let a = MutablePrimitiveArray::::from_slice([0, 1]).as_box(); + assert_eq!(a.len(), 2); +} + +#[test] +fn shrink_to_fit_and_capacity() { + let mut a = MutablePrimitiveArray::::with_capacity(100); + a.push(Some(1)); + a.try_push(None).unwrap(); + assert!(a.capacity() >= 100); + (&mut a as &mut dyn MutableArray).shrink_to_fit(); + assert_eq!(a.capacity(), 2); +} + +#[test] +fn only_nulls() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.push(None); + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([false, false]))); +} + +#[test] +fn from_trusted_len() { + let a = + MutablePrimitiveArray::::from_trusted_len_iter(vec![Some(1), None, None].into_iter()); + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([true, false, false]))); + + let a = unsafe { + MutablePrimitiveArray::::from_trusted_len_iter_unchecked( + vec![Some(1), None].into_iter(), + ) + }; + let a: PrimitiveArray = a.into(); + assert_eq!(a.validity(), Some(&Bitmap::from([true, false]))); +} + +#[test] +fn extend_trusted_len() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len(vec![Some(1), Some(2)].into_iter()); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); + a.extend_trusted_len(vec![None, Some(4)].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, true, false, true])) + ); + assert_eq!(a.values(), &Vec::::from([1, 2, 0, 4])); +} + +#[test] +fn extend_constant_no_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.extend_constant(2, Some(3)); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 3, 3])); +} + +#[test] +fn extend_constant_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.push(Some(1)); + a.extend_constant(2, None); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([true, false, false])) + ); + assert_eq!(a.values(), &Vec::::from([1, 0, 0])); +} + +#[test] +fn extend_constant_validity_inverse() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_constant(2, Some(1)); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); + assert_eq!(a.values(), &Vec::::from([0, 1, 1])); +} + +#[test] +fn extend_constant_validity_none() { + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_constant(2, None); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, false, false])) + ); + assert_eq!(a.values(), &Vec::::from([0, 0, 0])); +} + +#[test] +fn extend_trusted_len_values() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len_values(vec![1, 2, 3].into_iter()); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 2, 3])); + + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_trusted_len_values(vec![1, 2].into_iter()); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); +} + +#[test] +fn extend_from_slice() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_from_slice(&[1, 2, 3]); + assert_eq!(a.validity(), None); + assert_eq!(a.values(), &Vec::::from([1, 2, 3])); + + let mut a = MutablePrimitiveArray::::new(); + a.push(None); + a.extend_from_slice(&[1, 2]); + assert_eq!( + a.validity(), + Some(&MutableBitmap::from([false, true, true])) + ); +} + +#[test] +fn set_validity() { + let mut a = MutablePrimitiveArray::::new(); + a.extend_trusted_len(vec![Some(1), Some(2)].into_iter()); + let validity = a.validity().unwrap(); + assert_eq!(validity.unset_bits(), 0); + + // test that upon conversion to array the bitmap is set to None + let arr: PrimitiveArray<_> = a.clone().into(); + assert_eq!(arr.validity(), None); + + // test set_validity + a.set_validity(Some(MutableBitmap::from([false, true]))); + assert_eq!(a.validity(), Some(&MutableBitmap::from([false, true]))); +} + +#[test] +fn set_values() { + let mut a = MutablePrimitiveArray::::from_slice([1, 2]); + a.set_values(Vec::from([1, 3])); + assert_eq!(a.values().as_slice(), [1, 3]); +} + +#[test] +fn try_from_trusted_len_iter() { + let iter = std::iter::repeat(Some(1)).take(2).map(PolarsResult::Ok); + let a = MutablePrimitiveArray::try_from_trusted_len_iter(iter).unwrap(); + assert_eq!(a, MutablePrimitiveArray::from([Some(1), Some(1)])); +} + +#[test] +fn wrong_data_type() { + assert!(MutablePrimitiveArray::::try_new(ArrowDataType::Utf8, vec![], None).is_err()); +} + +#[test] +fn extend_from_self() { + let mut a = MutablePrimitiveArray::from([Some(1), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutablePrimitiveArray::from([Some(1), None, Some(1), None]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs b/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs new file mode 100644 index 0000000000000..0cc32155a3181 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/primitive/to_mutable.rs @@ -0,0 +1,53 @@ +use arrow::array::PrimitiveArray; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; +use either::Either; + +#[test] +fn array_to_mutable() { + let data = vec![1, 2, 3]; + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), None); + + // to mutable push and freeze again + let mut mut_arr = arr.into_mut().unwrap_right(); + mut_arr.push(Some(5)); + let immut: PrimitiveArray = mut_arr.into(); + assert_eq!(immut.values().as_slice(), [1, 2, 3, 5]); + + // let's cause a realloc and see if miri is ok + let mut mut_arr = immut.into_mut().unwrap_right(); + mut_arr.extend_constant(256, Some(9)); + let immut: PrimitiveArray = mut_arr.into(); + assert_eq!(immut.values().len(), 256 + 4); +} + +#[test] +fn array_to_mutable_not_owned() { + let data = vec![1, 2, 3]; + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), None); + let arr2 = arr.clone(); + + // to the `to_mutable` should fail and we should get back the original array + match arr2.into_mut() { + Either::Left(arr2) => { + assert_eq!(arr, arr2); + }, + _ => panic!(), + } +} + +#[test] +#[allow(clippy::redundant_clone)] +fn array_to_mutable_validity() { + let data = vec![1, 2, 3]; + + // both have a single reference should be ok + let bitmap = Bitmap::from_iter([true, false, true]); + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.clone().into(), Some(bitmap)); + assert!(matches!(arr.into_mut(), Either::Right(_))); + + // now we clone the bitmap increasing the ref count + let bitmap = Bitmap::from_iter([true, false, true]); + let arr = PrimitiveArray::new(ArrowDataType::Int32, data.into(), Some(bitmap.clone())); + assert!(matches!(arr.into_mut(), Either::Left(_))); +} diff --git a/crates/polars/tests/it/arrow/array/struct_/iterator.rs b/crates/polars/tests/it/arrow/array/struct_/iterator.rs new file mode 100644 index 0000000000000..5b4b0b784d136 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/iterator.rs @@ -0,0 +1,28 @@ +use arrow::array::*; +use arrow::datatypes::*; +use arrow::scalar::new_scalar; + +#[test] +fn test_simple_iter() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", ArrowDataType::Boolean, false), + Field::new("c", ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + vec![boolean.clone(), int.clone()], + None, + ); + + for (i, item) in array.iter().enumerate() { + let expected = Some(vec![ + new_scalar(boolean.as_ref(), i), + new_scalar(int.as_ref(), i), + ]); + assert_eq!(expected, item); + } +} diff --git a/crates/polars/tests/it/arrow/array/struct_/mod.rs b/crates/polars/tests/it/arrow/array/struct_/mod.rs new file mode 100644 index 0000000000000..ae1a0c0a37cba --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/mod.rs @@ -0,0 +1,27 @@ +mod iterator; +mod mutable; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::datatypes::*; + +#[test] +fn debug() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", ArrowDataType::Boolean, false), + Field::new("c", ArrowDataType::Int32, false), + ]; + + let array = StructArray::new( + ArrowDataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + assert_eq!( + format!("{array:?}"), + "StructArray[{b: false, c: 42}, {b: false, c: 28}, None, {b: true, c: 31}]" + ); +} diff --git a/crates/polars/tests/it/arrow/array/struct_/mutable.rs b/crates/polars/tests/it/arrow/array/struct_/mutable.rs new file mode 100644 index 0000000000000..e9d698aa1bb3f --- /dev/null +++ b/crates/polars/tests/it/arrow/array/struct_/mutable.rs @@ -0,0 +1,31 @@ +use arrow::array::*; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn push() { + let c1 = Box::new(MutablePrimitiveArray::::new()) as Box; + let values = vec![c1]; + let data_type = ArrowDataType::Struct(vec![Field::new("f1", ArrowDataType::Int32, true)]); + let mut a = MutableStructArray::new(data_type, values); + + a.value::>(0) + .unwrap() + .push(Some(1)); + a.push(true); + a.value::>(0).unwrap().push(None); + a.push(false); + a.value::>(0) + .unwrap() + .push(Some(2)); + a.push(true); + + assert_eq!(a.len(), 3); + assert!(a.is_valid(0)); + assert!(!a.is_valid(1)); + assert!(a.is_valid(2)); + + assert_eq!( + a.value::>(0).unwrap().values(), + &Vec::from([1, 0, 2]) + ); +} diff --git a/crates/polars/tests/it/arrow/array/union.rs b/crates/polars/tests/it/arrow/array/union.rs new file mode 100644 index 0000000000000..b358aa8e44bb8 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/union.rs @@ -0,0 +1,371 @@ +use arrow::array::*; +use arrow::buffer::Buffer; +use arrow::datatypes::*; +use arrow::scalar::{new_scalar, PrimitiveScalar, Scalar, UnionScalar, Utf8Scalar}; +use polars_error::PolarsResult; + +fn next_unwrap(iter: &mut I) -> T +where + I: Iterator>, + T: Clone + 'static, +{ + iter.next() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .clone() +} + +#[test] +fn sparse_debug() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields, None); + + assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); + + Ok(()) +} + +#[test] +fn dense_debug() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + let offsets = Some(vec![0, 1, 0].into()); + + let array = UnionArray::new(data_type, types, fields, offsets); + + assert_eq!(format!("{array:?}"), "UnionArray[1, None, c]"); + + Ok(()) +} + +#[test] +fn slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::LargeUtf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type.clone(), types, fields.clone(), None); + + let result = array.sliced(1, 2); + + let sliced_types = Buffer::from(vec![0, 1]); + let sliced_fields = vec![ + Int32Array::from(&[None, Some(2)]).boxed(), + Utf8Array::::from([Some("b"), Some("c")]).boxed(), + ]; + let expected = UnionArray::new(data_type, sliced_types, sliced_fields, None); + + assert_eq!(expected, result); + Ok(()) +} + +#[test] +fn iter_sparse() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), None, Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), None); + let mut iter = array.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(1) + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &None + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + Some("c") + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_dense() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), None]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let mut iter = array.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(1) + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &None + ); + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + Some("c") + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_sparse_slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = Buffer::from(vec![0, 0, 1]); + let fields = vec![ + Int32Array::from(&[Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), None); + let array_slice = array.sliced(1, 1); + let mut iter = array_slice.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn iter_dense_slice() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), Some(3)]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + let array_slice = array.sliced(1, 1); + let mut iter = array_slice.iter(); + + assert_eq!( + next_unwrap::, _>(&mut iter).value(), + &Some(3) + ); + assert_eq!(iter.next(), None); + + Ok(()) +} + +#[test] +fn scalar() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = Buffer::from(vec![0, 0, 1]); + let offsets = Buffer::::from(vec![0, 1, 0]); + let fields = vec![ + Int32Array::from(&[Some(1), None]).boxed(), + Utf8Array::::from([Some("c")]).boxed(), + ]; + + let array = UnionArray::new(data_type, types, fields.clone(), Some(offsets)); + + let scalar = new_scalar(&array, 0); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + &Some(1) + ); + assert_eq!(union_scalar.type_(), 0); + let scalar = new_scalar(&array, 1); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + &None + ); + assert_eq!(union_scalar.type_(), 0); + + let scalar = new_scalar(&array, 2); + let union_scalar = scalar.as_any().downcast_ref::().unwrap(); + assert_eq!( + union_scalar + .value() + .as_any() + .downcast_ref::>() + .unwrap() + .value(), + Some("c") + ); + assert_eq!(union_scalar.type_(), 1); + + Ok(()) +} + +#[test] +fn dense_without_offsets_is_error() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Dense); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); +} + +#[test] +fn fields_must_match() { + let fields = vec![ + Field::new("a", ArrowDataType::Int64, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let types = vec![0, 0, 1].into(); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); +} + +#[test] +fn sparse_with_offsets_is_error() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + let offsets = vec![0, 1, 0].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn offsets_must_be_in_bounds() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length og types + let offsets = vec![0, 1].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn sparse_with_wrong_offsets1_is_error() { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + let types = vec![0, 0, 1].into(); + // it must be equal to length of types + let offsets = vec![0, 1, 10].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), Some(offsets)).is_err()); +} + +#[test] +fn types_must_be_in_bounds() -> PolarsResult<()> { + let fields = vec![ + Field::new("a", ArrowDataType::Int32, true), + Field::new("b", ArrowDataType::Utf8, true), + ]; + let data_type = ArrowDataType::Union(fields, None, UnionMode::Sparse); + let fields = vec![ + Int32Array::from([Some(1), Some(3), Some(2)]).boxed(), + Utf8Array::::from([Some("a"), Some("b"), Some("c")]).boxed(), + ]; + + // 10 > num fields + let types = vec![0, 10].into(); + + assert!(UnionArray::try_new(data_type, types, fields.clone(), None).is_err()); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mod.rs b/crates/polars/tests/it/arrow/array/utf8/mod.rs new file mode 100644 index 0000000000000..fb75990dad29b --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mod.rs @@ -0,0 +1,237 @@ +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; +use polars_error::PolarsResult; + +mod mutable; +mod mutable_values; +mod to_mutable; + +#[test] +fn basics() { + let data = vec![Some("hello"), None, Some("hello2")]; + + let array: Utf8Array = data.into_iter().collect(); + + assert_eq!(array.value(0), "hello"); + assert_eq!(array.value(1), ""); + assert_eq!(array.value(2), "hello2"); + assert_eq!(unsafe { array.value_unchecked(2) }, "hello2"); + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[0, 5, 5, 11]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00000101], 3)) + ); + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); + + let array2 = Utf8Array::::new( + ArrowDataType::Utf8, + array.offsets().clone(), + array.values().clone(), + array.validity().cloned(), + ); + assert_eq!(array, array2); + + let array = array.sliced(1, 2); + assert_eq!(array.value(0), ""); + assert_eq!(array.value(1), "hello2"); + // note how this keeps everything: the offsets were sliced + assert_eq!(array.values().as_slice(), b"hellohello2"); + assert_eq!(array.offsets().as_slice(), &[5, 5, 11]); +} + +#[test] +fn empty() { + let array = Utf8Array::::new_empty(ArrowDataType::Utf8); + assert_eq!(array.values().as_slice(), b""); + assert_eq!(array.offsets().as_slice(), &[0]); + assert_eq!(array.validity(), None); +} + +#[test] +fn from() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + let a = array.validity().unwrap(); + assert_eq!(a, &Bitmap::from([true, true, false])); +} + +#[test] +fn from_slice() { + let b = Utf8Array::::from_slice(["a", "b", "cc"]); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn from_iter_values() { + let b = Utf8Array::::from_iter_values(["a", "b", "cc"].iter()); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn from_trusted_len_iter() { + let b = + Utf8Array::::from_trusted_len_iter(vec![Some("a"), Some("b"), Some("cc")].into_iter()); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn try_from_trusted_len_iter() { + let b = Utf8Array::::try_from_trusted_len_iter( + vec![Some("a"), Some("b"), Some("cc")] + .into_iter() + .map(PolarsResult::Ok), + ) + .unwrap(); + + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abcc".to_vec().into(); + assert_eq!( + b, + Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None) + ); +} + +#[test] +fn not_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150].into(); // invalid utf8 + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn not_utf8_individually() { + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = vec![207, 128].into(); // each is invalid utf8, but together is valid + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn wrong_data_type() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + assert!(Utf8Array::::try_new(ArrowDataType::Int32, offsets, values, None).is_err()); +} + +#[test] +fn out_of_bounds_offsets_panics() { + // the 10 is out of bounds + let offsets = vec![0, 10, 11].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + assert!(Utf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +#[should_panic] +fn index_out_of_bounds_panics() { + let offsets = vec![0, 1, 2, 4].try_into().unwrap(); + let values = b"abbb".to_vec().into(); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + + array.value(3); +} + +#[test] +fn debug() { + let array = Utf8Array::::from([Some("aa"), Some(""), None]); + + assert_eq!(format!("{array:?}"), "Utf8Array[aa, , None]"); +} + +#[test] +fn into_mut_1() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = Buffer::from(b"a".to_vec()); + let a = values.clone(); // cloned values + assert_eq!(a, values); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_2() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let a = offsets.clone(); // cloned offsets + assert_eq!(a, offsets); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, None); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_3() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let validity = Some([true].into()); + let a = validity.clone(); // cloned validity + assert_eq!(a, validity); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, validity); + assert!(array.into_mut().is_left()); +} + +#[test] +fn into_mut_4() { + let offsets = vec![0, 1].try_into().unwrap(); + let values = b"a".to_vec().into(); + let validity = Some([true].into()); + let array = Utf8Array::::new(ArrowDataType::Utf8, offsets, values, validity); + assert!(array.into_mut().is_right()); +} + +#[test] +fn rev_iter() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + assert_eq!( + array.into_iter().rev().collect::>(), + vec![None, Some(" "), Some("hello")] + ); +} + +#[test] +fn iter_nth() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + + assert_eq!(array.iter().nth(1), Some(Some(" "))); + assert_eq!(array.iter().nth(10), None); +} + +#[test] +fn test_apply_validity() { + let mut array = Utf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|bitmap| { + let mut mut_bitmap = bitmap.into_mut().right().unwrap(); + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap.into() + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable.rs b/crates/polars/tests/it/arrow/array/utf8/mutable.rs new file mode 100644 index 0000000000000..8db873a90d10b --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mutable.rs @@ -0,0 +1,242 @@ +use arrow::array::{MutableArray, MutableUtf8Array, TryExtendFromSelf, Utf8Array}; +use arrow::bitmap::Bitmap; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacities() { + let b = MutableUtf8Array::::with_capacities(1, 10); + + assert!(b.values().capacity() >= 10); + assert!(b.offsets().capacity() >= 1); +} + +#[test] +fn push_null() { + let mut array = MutableUtf8Array::::new(); + array.push::<&str>(None); + + let array: Utf8Array = array.into(); + assert_eq!(array.validity(), Some(&Bitmap::from([false]))); +} + +#[test] +fn pop() { + let mut a = MutableUtf8Array::::new(); + a.push(Some("first")); + a.push(Some("second")); + a.push(Some("third")); + a.push::<&str>(None); + + assert_eq!(a.pop(), None); + assert_eq!(a.len(), 3); + assert_eq!(a.pop(), Some("third".to_owned())); + assert_eq!(a.len(), 2); + assert_eq!(a.pop(), Some("second".to_string())); + assert_eq!(a.len(), 1); + assert_eq!(a.pop(), Some("first".to_string())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); + assert!(a.is_empty()); +} + +#[test] +fn pop_all_some() { + let mut a = MutableUtf8Array::::new(); + a.push(Some("first")); + a.push(Some("second")); + a.push(Some("third")); + a.push(Some("fourth")); + for _ in 0..4 { + a.push(Some("aaaa")); + } + a.push(Some("こんにちは")); + + assert_eq!(a.pop(), Some("こんにちは".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.len(), 5); + assert_eq!(a.pop(), Some("aaaa".to_string())); + assert_eq!(a.pop(), Some("fourth".to_string())); + assert_eq!(a.pop(), Some("third".to_string())); + assert_eq!(a.pop(), Some("second".to_string())); + assert_eq!(a.pop(), Some("first".to_string())); + assert!(a.is_empty()); + assert_eq!(a.pop(), None); +} + +/// Safety guarantee +#[test] +fn not_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150]; // invalid utf8 + assert!(MutableUtf8Array::::try_new(ArrowDataType::Utf8, offsets, values, None).is_err()); +} + +#[test] +fn wrong_data_type() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![1, 2, 3, 4]; + assert!(MutableUtf8Array::::try_new(ArrowDataType::Int8, offsets, values, None).is_err()); +} + +#[test] +fn test_extend_trusted_len_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len_values(["hi", "there"].iter()); + array.extend_trusted_len_values(["hello"].iter()); + array.extend_trusted_len(vec![Some("again"), None].into_iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17, 17]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00001111], 5)) + ); +} + +#[test] +fn test_extend_trusted_len() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 7, 12, 17]); + assert_eq!( + array.validity(), + Some(&Bitmap::from_u8_slice([0b00011011], 5)) + ); +} + +#[test] +fn test_extend_values() { + let mut array = MutableUtf8Array::::new(); + + array.extend_values([Some("hi"), None, Some("there"), None].iter().flatten()); + array.extend_values([Some("hello"), None].iter().flatten()); + array.extend_values(vec![Some("again"), None].into_iter().flatten()); + + let array: Utf8Array = array.into(); + + assert_eq!(array.values().as_slice(), b"hitherehelloagain"); + assert_eq!(array.offsets().as_slice(), &[0, 2, 7, 12, 17]); + assert_eq!(array.validity(), None,); +} + +#[test] +fn test_extend() { + let mut array = MutableUtf8Array::::new(); + + array.extend([Some("hi"), None, Some("there"), None]); + + let array: Utf8Array = array.into(); + + assert_eq!( + array, + Utf8Array::::from([Some("hi"), None, Some("there"), None]) + ); +} + +#[test] +fn as_arc() { + let mut array = MutableUtf8Array::::new(); + + array.extend([Some("hi"), None, Some("there"), None]); + + assert_eq!( + Utf8Array::::from([Some("hi"), None, Some("there"), None]), + array.as_arc().as_ref() + ); +} + +#[test] +fn test_iter() { + let mut array = MutableUtf8Array::::new(); + + array.extend_trusted_len(vec![Some("hi"), Some("there")].into_iter()); + array.extend_trusted_len(vec![None, Some("hello")].into_iter()); + array.extend_trusted_len_values(["again"].iter()); + + let result = array.iter().collect::>(); + assert_eq!( + result, + vec![ + Some("hi"), + Some("there"), + None, + Some("hello"), + Some("again"), + ] + ); +} + +#[test] +fn as_box_twice() { + let mut a = MutableUtf8Array::::new(); + let _ = a.as_box(); + let _ = a.as_box(); + let mut a = MutableUtf8Array::::new(); + let _ = a.as_arc(); + let _ = a.as_arc(); +} + +#[test] +fn extend_from_self() { + let mut a = MutableUtf8Array::::from([Some("aa"), None]); + + a.try_extend_from_self(&a.clone()).unwrap(); + + assert_eq!( + a, + MutableUtf8Array::::from([Some("aa"), None, Some("aa"), None]) + ); +} + +#[test] +fn test_set_validity() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([false, false, true].into())); + + assert!(!array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(array.is_valid(2)); +} + +#[test] +fn test_apply_validity() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + array.set_validity(Some([true, true, true].into())); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(!array.is_valid(1)); + assert!(!array.is_valid(2)); +} + +#[test] +fn test_apply_validity_with_no_validity_inited() { + let mut array = MutableUtf8Array::::from([Some("Red"), Some("Green"), Some("Blue")]); + + array.apply_validity(|mut mut_bitmap| { + mut_bitmap.set(1, false); + mut_bitmap.set(2, false); + mut_bitmap + }); + + assert!(array.is_valid(0)); + assert!(array.is_valid(1)); + assert!(array.is_valid(2)); +} diff --git a/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs new file mode 100644 index 0000000000000..d4a3099499349 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/mutable_values.rs @@ -0,0 +1,105 @@ +use arrow::array::{MutableArray, MutableUtf8ValuesArray}; +use arrow::datatypes::ArrowDataType; + +#[test] +fn capacity() { + let mut b = MutableUtf8ValuesArray::::with_capacity(100); + + assert_eq!(b.values().capacity(), 0); + assert!(b.offsets().capacity() >= 100); + b.shrink_to_fit(); + assert!(b.offsets().capacity() < 100); +} + +#[test] +fn offsets_must_be_in_bounds() { + let offsets = vec![0, 10].try_into().unwrap(); + let values = b"abbbbb".to_vec(); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).is_err()); +} + +#[test] +fn data_type_must_be_consistent() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = b"abbb".to_vec(); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Int32, offsets, values).is_err()); +} + +#[test] +fn must_be_utf8() { + let offsets = vec![0, 4].try_into().unwrap(); + let values = vec![0, 159, 146, 150]; + assert!(std::str::from_utf8(&values).is_err()); + assert!(MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).is_err()); +} + +#[test] +fn as_box() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + let _ = b.as_box(); +} + +#[test] +fn as_arc() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + let _ = b.as_arc(); +} + +#[test] +fn extend_trusted_len() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 2, 3, 4].try_into().unwrap(); + let values = b"abab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn from_trusted_len() { + let mut b = MutableUtf8ValuesArray::::from_trusted_len_iter(vec!["a", "b"].into_iter()); + + let offsets = vec![0, 1, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} + +#[test] +fn extend_from_iter() { + let offsets = vec![0, 2].try_into().unwrap(); + let values = b"ab".to_vec(); + let mut b = + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values).unwrap(); + b.extend_trusted_len(vec!["a", "b"].into_iter()); + + let a = b.clone(); + b.extend_trusted_len(a.iter()); + + let offsets = vec![0, 2, 3, 4, 6, 7, 8].try_into().unwrap(); + let values = b"abababab".to_vec(); + assert_eq!( + b.as_box(), + MutableUtf8ValuesArray::::try_new(ArrowDataType::Utf8, offsets, values) + .unwrap() + .as_box() + ) +} diff --git a/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs b/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs new file mode 100644 index 0000000000000..5f2624368bf70 --- /dev/null +++ b/crates/polars/tests/it/arrow/array/utf8/to_mutable.rs @@ -0,0 +1,71 @@ +use arrow::array::Utf8Array; +use arrow::bitmap::Bitmap; +use arrow::buffer::Buffer; +use arrow::datatypes::ArrowDataType; +use arrow::offset::OffsetsBuffer; + +#[test] +fn not_shared() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + assert!(array.into_mut().is_right()); +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_validity() { + let validity = Bitmap::from([true]); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + vec![0, 1].try_into().unwrap(), + b"a".to_vec().into(), + Some(validity.clone()), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_values() { + let values: Buffer = b"a".to_vec().into(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + vec![0, 1].try_into().unwrap(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets_values() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let values: Buffer = b"a".to_vec().into(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + offsets.clone(), + values.clone(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_offsets() { + let offsets: OffsetsBuffer = vec![0, 1].try_into().unwrap(); + let array = Utf8Array::::new( + ArrowDataType::Utf8, + offsets.clone(), + b"a".to_vec().into(), + Some(Bitmap::from([true])), + ); + assert!(array.into_mut().is_left()) +} + +#[test] +#[allow(clippy::redundant_clone)] +fn shared_all() { + let array = Utf8Array::::from([Some("hello"), Some(" "), None]); + assert!(array.clone().into_mut().is_left()) +} diff --git a/crates/polars/tests/it/arrow/bitmap/assign_ops.rs b/crates/polars/tests/it/arrow/bitmap/assign_ops.rs new file mode 100644 index 0000000000000..939133f0a5fea --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/assign_ops.rs @@ -0,0 +1,78 @@ +use arrow::bitmap::{binary_assign, unary_assign, Bitmap, MutableBitmap}; +use proptest::prelude::*; + +use super::bitmap_strategy; + +#[test] +fn basics() { + let mut b = MutableBitmap::from_iter(std::iter::repeat(true).take(10)); + unary_assign(&mut b, |x: u8| !x); + assert_eq!( + b, + MutableBitmap::from_iter(std::iter::repeat(false).take(10)) + ); + + let mut b = MutableBitmap::from_iter(std::iter::repeat(true).take(10)); + let c = Bitmap::from_iter(std::iter::repeat(true).take(10)); + binary_assign(&mut b, &c, |x: u8, y| x | y); + assert_eq!( + b, + MutableBitmap::from_iter(std::iter::repeat(true).take(10)) + ); +} + +#[test] +fn binary_assign_oob() { + // this check we don't have an oob access if the bitmaps are size T + 1 + // and we do some slicing. + let a = MutableBitmap::from_iter(std::iter::repeat(true).take(65)); + let b = MutableBitmap::from_iter(std::iter::repeat(true).take(65)); + + let a: Bitmap = a.into(); + let a = a.sliced(10, 20); + + let b: Bitmap = b.into(); + let b = b.sliced(10, 20); + + let mut a = a.make_mut(); + + binary_assign(&mut a, &b, |x: u64, y| x & y); +} + +#[test] +fn fast_paths() { + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([true, true]); + let b = b & &c; + assert_eq!(b, MutableBitmap::from_iter([true, false])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([false, false]); + let b = b & &c; + assert_eq!(b, MutableBitmap::from_iter([false, false])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([true, true]); + let b = b | &c; + assert_eq!(b, MutableBitmap::from_iter([true, true])); + + let b = MutableBitmap::from([true, false]); + let c = Bitmap::from_iter([false, false]); + let b = b | &c; + assert_eq!(b, MutableBitmap::from_iter([true, false])); +} + +proptest! { + /// Asserts that !bitmap equals all bits flipped + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn not(b in bitmap_strategy()) { + let not_b: MutableBitmap = b.iter().map(|x| !x).collect(); + + let mut b = b.make_mut(); + + unary_assign(&mut b, |x: u8| !x); + + assert_eq!(b, not_b); + } +} diff --git a/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs b/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs new file mode 100644 index 0000000000000..e7fb3636e2181 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/bitmap_ops.rs @@ -0,0 +1,40 @@ +use arrow::bitmap::{and, or, xor, Bitmap}; +use proptest::prelude::*; + +use super::bitmap_strategy; + +proptest! { + /// Asserts that !bitmap equals all bits flipped + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn not(bitmap in bitmap_strategy()) { + let not_bitmap: Bitmap = bitmap.iter().map(|x| !x).collect(); + + assert_eq!(!&bitmap, not_bitmap); + } +} + +#[test] +fn test_fast_paths() { + let all_true = Bitmap::from(&[true, true]); + let all_false = Bitmap::from(&[false, false]); + let toggled = Bitmap::from(&[true, false]); + + assert_eq!(and(&all_true, &all_true), all_true); + assert_eq!(and(&all_false, &all_true), all_false); + assert_eq!(and(&all_true, &all_false), all_false); + assert_eq!(and(&toggled, &all_false), all_false); + assert_eq!(and(&toggled, &all_true), toggled); + + assert_eq!(or(&all_true, &all_true), all_true); + assert_eq!(or(&all_true, &all_false), all_true); + assert_eq!(or(&all_false, &all_true), all_true); + assert_eq!(or(&all_false, &all_false), all_false); + assert_eq!(or(&toggled, &all_false), toggled); + + assert_eq!(xor(&all_true, &all_true), all_false); + assert_eq!(xor(&all_true, &all_false), all_true); + assert_eq!(xor(&all_false, &all_true), all_true); + assert_eq!(xor(&all_false, &all_false), all_false); + assert_eq!(xor(&toggled, &toggled), all_false); +} diff --git a/crates/polars/tests/it/arrow/bitmap/immutable.rs b/crates/polars/tests/it/arrow/bitmap/immutable.rs new file mode 100644 index 0000000000000..29324c96d771b --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/immutable.rs @@ -0,0 +1,67 @@ +use arrow::bitmap::Bitmap; + +#[test] +fn as_slice() { + let b = Bitmap::from([true, true, true, true, true, true, true, true, true]); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0b11111111, 0b1]); + assert_eq!(offset, 0); + assert_eq!(length, 9); +} + +#[test] +fn as_slice_offset() { + let b = Bitmap::from([true, true, true, true, true, true, true, true, true]); + let b = b.sliced(8, 1); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0b1]); + assert_eq!(offset, 0); + assert_eq!(length, 1); +} + +#[test] +fn as_slice_offset_middle() { + let b = Bitmap::from_u8_slice([0, 0, 0, 0b00010101], 27); + let b = b.sliced(22, 5); + + let (slice, offset, length) = b.as_slice(); + assert_eq!(slice, &[0, 0b00010101]); + assert_eq!(offset, 6); + assert_eq!(length, 5); +} + +#[test] +fn debug() { + let b = Bitmap::from([true, true, false, true, true, true, true, true, true]); + let b = b.sliced(2, 7); + + assert_eq!(format!("{b:?}"), "[0b111110__, 0b_______1]"); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow() { + use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; + let buffer = arrow_buffer::Buffer::from_iter(vec![true, true, true, false, false, false, true]); + let bools = BooleanBuffer::new(buffer, 0, 7); + let nulls = NullBuffer::new(bools); + assert_eq!(nulls.null_count(), 3); + + let bitmap = Bitmap::from_null_buffer(nulls.clone()); + assert_eq!(nulls.null_count(), bitmap.unset_bits()); + assert_eq!(nulls.len(), bitmap.len()); + let back = NullBuffer::from(bitmap); + assert_eq!(nulls, back); + + let nulls = nulls.slice(1, 3); + assert_eq!(nulls.null_count(), 1); + assert_eq!(nulls.len(), 3); + + let bitmap = Bitmap::from_null_buffer(nulls.clone()); + assert_eq!(nulls.null_count(), bitmap.unset_bits()); + assert_eq!(nulls.len(), bitmap.len()); + let back = NullBuffer::from(bitmap); + assert_eq!(nulls, back); +} diff --git a/crates/polars/tests/it/arrow/bitmap/mod.rs b/crates/polars/tests/it/arrow/bitmap/mod.rs new file mode 100644 index 0000000000000..88758695b7620 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/mod.rs @@ -0,0 +1,124 @@ +mod assign_ops; +mod bitmap_ops; +mod immutable; +mod mutable; +mod utils; + +use arrow::bitmap::Bitmap; +use proptest::prelude::*; + +/// Returns a strategy of an arbitrary sliced [`Bitmap`] of size up to 1000 +pub(crate) fn bitmap_strategy() -> impl Strategy { + prop::collection::vec(any::(), 1..1000) + .prop_flat_map(|vec| { + let len = vec.len(); + (Just(vec), 0..len) + }) + .prop_flat_map(|(vec, index)| { + let len = vec.len(); + (Just(vec), Just(index), 0..len - index) + }) + .prop_flat_map(|(vec, index, len)| { + let bitmap = Bitmap::from(&vec); + let bitmap = bitmap.sliced(index, len); + Just(bitmap) + }) +} + +fn create_bitmap>(bytes: P, len: usize) -> Bitmap { + let buffer = Vec::::from(bytes.as_ref()); + Bitmap::from_u8_vec(buffer, len) +} + +#[test] +fn eq() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + assert!(lhs != rhs); +} + +#[test] +fn eq_len() { + let lhs = create_bitmap([0b01101010], 6); + let rhs = create_bitmap([0b00101010], 6); + assert!(lhs == rhs); + let rhs = create_bitmap([0b00001010], 6); + assert!(lhs != rhs); +} + +#[test] +fn eq_slice() { + let lhs = create_bitmap([0b10101010], 8).sliced(1, 7); + let rhs = create_bitmap([0b10101011], 8).sliced(1, 7); + assert!(lhs == rhs); + + let lhs = create_bitmap([0b10101010], 8).sliced(2, 6); + let rhs = create_bitmap([0b10101110], 8).sliced(2, 6); + assert!(lhs != rhs); +} + +#[test] +fn and() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + let expected = create_bitmap([0b01001010], 8); + assert_eq!(&lhs & &rhs, expected); +} + +#[test] +fn or_large() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000010, 0b11111111, + ]; + let input1: &[u8] = &[ + 0b00000000, 0b00000001, 0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000, + 0b10000000, 0b11111111, + ]; + let expected: &[u8] = &[ + 0b00000000, 0b00000001, 0b10000010, 0b10000100, 0b10001000, 0b10010000, 0b10100000, + 0b11000010, 0b11111111, + ]; + + let lhs = create_bitmap(input, 62); + let rhs = create_bitmap(input1, 62); + let expected = create_bitmap(expected, 62); + assert_eq!(&lhs | &rhs, expected); +} + +#[test] +fn and_offset() { + let lhs = create_bitmap([0b01101011], 8).sliced(1, 7); + let rhs = create_bitmap([0b01001111], 8).sliced(1, 7); + let expected = create_bitmap([0b01001010], 8).sliced(1, 7); + assert_eq!(&lhs & &rhs, expected); +} + +#[test] +fn or() { + let lhs = create_bitmap([0b01101010], 8); + let rhs = create_bitmap([0b01001110], 8); + let expected = create_bitmap([0b01101110], 8); + assert_eq!(&lhs | &rhs, expected); +} + +#[test] +fn not() { + let lhs = create_bitmap([0b01101010], 6); + let expected = create_bitmap([0b00010101], 6); + assert_eq!(!&lhs, expected); +} + +#[test] +fn subslicing_gives_correct_null_count() { + let base = Bitmap::from([false, true, true, false, false, true, true, true]); + assert_eq!(base.unset_bits(), 3); + + let view1 = base.clone().sliced(0, 1); + let view2 = base.sliced(1, 7); + assert_eq!(view1.unset_bits(), 1); + assert_eq!(view2.unset_bits(), 2); + + let view3 = view2.sliced(0, 1); + assert_eq!(view3.unset_bits(), 0); +} diff --git a/crates/polars/tests/it/arrow/bitmap/mutable.rs b/crates/polars/tests/it/arrow/bitmap/mutable.rs new file mode 100644 index 0000000000000..af37d634a4687 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/mutable.rs @@ -0,0 +1,437 @@ +use arrow::bitmap::{Bitmap, MutableBitmap}; + +#[test] +fn from_slice() { + let slice = &[true, false, true]; + let a = MutableBitmap::from(slice); + assert_eq!(a.iter().collect::>(), slice); +} + +#[test] +fn from_len_zeroed() { + let a = MutableBitmap::from_len_zeroed(10); + assert_eq!(a.len(), 10); + assert_eq!(a.unset_bits(), 10); +} + +#[test] +fn from_len_set() { + let a = MutableBitmap::from_len_set(10); + assert_eq!(a.len(), 10); + assert_eq!(a.unset_bits(), 0); +} + +#[test] +fn try_new_invalid() { + assert!(MutableBitmap::try_new(vec![], 2).is_err()); +} + +#[test] +fn clear() { + let mut a = MutableBitmap::from_len_zeroed(10); + a.clear(); + assert_eq!(a.len(), 0); +} + +#[test] +fn trusted_len() { + let data = vec![true; 65]; + let bitmap = MutableBitmap::from_trusted_len_iter(data.into_iter()); + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 65); + + assert_eq!(bitmap.as_slice().0[8], 0b00000001); +} + +#[test] +fn trusted_len_small() { + let data = vec![true; 7]; + let bitmap = MutableBitmap::from_trusted_len_iter(data.into_iter()); + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 7); + + assert_eq!(bitmap.as_slice().0[0], 0b01111111); +} + +#[test] +fn push() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(true); + bitmap.push(false); + bitmap.push(false); + for _ in 0..7 { + bitmap.push(true) + } + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 10); + + assert_eq!(bitmap.as_slice().0, &[0b11111001, 0b00000011]); +} + +#[test] +fn push_small() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(true); + bitmap.push(true); + bitmap.push(false); + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.as_slice().0[0], 0b00000011); +} + +#[test] +fn push_exact_zeros() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(false) + } + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 8); + assert_eq!(bitmap.as_slice().0.len(), 1); +} + +#[test] +fn push_exact_ones() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(true) + } + let bitmap: Option = bitmap.into(); + assert!(bitmap.is_none()); +} + +#[test] +fn pop() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(false); + bitmap.push(true); + bitmap.push(false); + bitmap.push(true); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 3); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 2); + + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 2); + assert_eq!(bitmap.as_slice().0[0], 0b00001010); +} + +#[test] +fn pop_large() { + let mut bitmap = MutableBitmap::new(); + for _ in 0..8 { + bitmap.push(true); + } + + bitmap.push(false); + bitmap.push(true); + bitmap.push(false); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 10); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 9); + + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 8); + + let bitmap: Bitmap = bitmap.into(); + assert_eq!(bitmap.len(), 8); + assert_eq!(bitmap.as_slice().0, &[0b11111111]); +} + +#[test] +fn pop_all() { + let mut bitmap = MutableBitmap::new(); + bitmap.push(false); + bitmap.push(true); + bitmap.push(true); + bitmap.push(true); + + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 2); + assert_eq!(bitmap.pop(), Some(true)); + assert_eq!(bitmap.len(), 1); + assert_eq!(bitmap.pop(), Some(false)); + assert_eq!(bitmap.len(), 0); + assert_eq!(bitmap.pop(), None); + assert_eq!(bitmap.len(), 0); +} + +#[test] +fn capacity() { + let b = MutableBitmap::with_capacity(10); + assert!(b.capacity() >= 10); +} + +#[test] +fn capacity_push() { + let mut b = MutableBitmap::with_capacity(512); + (0..512).for_each(|_| b.push(true)); + assert_eq!(b.capacity(), 512); + b.reserve(8); + assert_eq!(b.capacity(), 1024); +} + +#[test] +fn extend() { + let mut b = MutableBitmap::new(); + + let iter = (0..512).map(|i| i % 6 == 0); + unsafe { b.extend_from_trusted_len_iter_unchecked(iter) }; + let b: Bitmap = b.into(); + for (i, v) in b.iter().enumerate() { + assert_eq!(i % 6 == 0, v); + } +} + +#[test] +fn extend_offset() { + let mut b = MutableBitmap::new(); + b.push(true); + + let iter = (0..512).map(|i| i % 6 == 0); + unsafe { b.extend_from_trusted_len_iter_unchecked(iter) }; + let b: Bitmap = b.into(); + let mut iter = b.iter().enumerate(); + assert!(iter.next().unwrap().1); + for (i, v) in iter { + assert_eq!((i - 1) % 6 == 0, v); + } +} + +#[test] +fn set() { + let mut bitmap = MutableBitmap::from_len_zeroed(12); + bitmap.set(0, true); + assert!(bitmap.get(0)); + bitmap.set(0, false); + assert!(!bitmap.get(0)); + + bitmap.set(11, true); + assert!(bitmap.get(11)); + bitmap.set(11, false); + assert!(!bitmap.get(11)); + bitmap.set(11, true); + + let bitmap: Option = bitmap.into(); + let bitmap = bitmap.unwrap(); + assert_eq!(bitmap.len(), 12); + assert_eq!(bitmap.as_slice().0[0], 0b00000000); +} + +#[test] +fn extend_from_bitmap() { + let other = Bitmap::from(&[true, false, true]); + let mut bitmap = MutableBitmap::new(); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 3); + assert_eq!(bitmap.as_slice()[0], 0b00000101); + + // this call iterates over all bits + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 6); + assert_eq!(bitmap.as_slice()[0], 0b00101101); +} + +#[test] +fn extend_from_bitmap_offset() { + let other = Bitmap::from_u8_slice([0b00111111], 8); + let mut bitmap = MutableBitmap::from_vec(vec![1, 0, 0b00101010], 22); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 22 + 8); + assert_eq!(bitmap.as_slice(), &[1, 0, 0b11101010, 0b00001111]); + + // more than one byte + let other = Bitmap::from_u8_slice([0b00111111, 0b00001111, 0b0001100], 20); + let mut bitmap = MutableBitmap::from_vec(vec![1, 0, 0b00101010], 22); + + // call is optimized to perform a memcopy + bitmap.extend_from_bitmap(&other); + + assert_eq!(bitmap.len(), 22 + 20); + assert_eq!( + bitmap.as_slice(), + &[1, 0, 0b11101010, 0b11001111, 0b0000011, 0b0000011] + ); +} + +#[test] +fn debug() { + let mut b = MutableBitmap::new(); + assert_eq!(format!("{b:?}"), "[]"); + b.push(true); + b.push(false); + assert_eq!(format!("{b:?}"), "[0b______01]"); + b.push(false); + b.push(false); + b.push(false); + b.push(false); + b.push(true); + b.push(true); + assert_eq!(format!("{b:?}"), "[0b11000001]"); + b.push(true); + assert_eq!(format!("{b:?}"), "[0b11000001, 0b_______1]"); +} + +#[test] +fn extend_set() { + let mut b = MutableBitmap::new(); + b.extend_constant(6, true); + assert_eq!(b.as_slice(), &[0b11111111]); + assert_eq!(b.len(), 6); + + let mut b = MutableBitmap::from(&[false]); + b.extend_constant(6, true); + assert_eq!(b.as_slice(), &[0b01111110]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[false]); + b.extend_constant(9, true); + assert_eq!(b.as_slice(), &[0b11111110, 0b11111111]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[false, false, false, false]); + b.extend_constant(2, true); + assert_eq!(b.as_slice(), &[0b00110000]); + assert_eq!(b.len(), 4 + 2); + + let mut b = MutableBitmap::from(&[false, false, false, false]); + b.extend_constant(8, true); + assert_eq!(b.as_slice(), &[0b11110000, 0b11111111]); + assert_eq!(b.len(), 4 + 8); + + let mut b = MutableBitmap::from(&[true, true]); + b.extend_constant(3, true); + assert_eq!(b.as_slice(), &[0b00011111]); + assert_eq!(b.len(), 2 + 3); +} + +#[test] +fn extend_unset() { + let mut b = MutableBitmap::new(); + b.extend_constant(6, false); + assert_eq!(b.as_slice(), &[0b0000000]); + assert_eq!(b.len(), 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_constant(6, false); + assert_eq!(b.as_slice(), &[0b00000001]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_constant(9, false); + assert_eq!(b.as_slice(), &[0b0000001, 0b00000000]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true]); + b.extend_constant(2, false); + assert_eq!(b.as_slice(), &[0b00001111]); + assert_eq!(b.len(), 4 + 2); +} + +#[test] +fn extend_bitmap() { + let mut b = MutableBitmap::from(&[true]); + b.extend_from_slice(&[0b00011001], 0, 6); + assert_eq!(b.as_slice(), &[0b00110011]); + assert_eq!(b.len(), 1 + 6); + + let mut b = MutableBitmap::from(&[true]); + b.extend_from_slice(&[0b00011001, 0b00011001], 0, 9); + assert_eq!(b.as_slice(), &[0b00110011, 0b00110010]); + assert_eq!(b.len(), 1 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true]); + b.extend_from_slice(&[0b00011001, 0b00011001], 0, 9); + assert_eq!(b.as_slice(), &[0b10011111, 0b10010001]); + assert_eq!(b.len(), 4 + 9); + + let mut b = MutableBitmap::from(&[true, true, true, true, true]); + b.extend_from_slice(&[0b00001011], 0, 4); + assert_eq!(b.as_slice(), &[0b01111111, 0b00000001]); + assert_eq!(b.len(), 5 + 4); +} + +// TODO! undo miri ignore once issue is fixed in miri +// this test was a memory hog and lead to OOM in CI +// given enough memory it was able to pass successfully on a local +#[test] +#[cfg_attr(miri, ignore)] +fn extend_constant1() { + use std::iter::FromIterator; + for i in 0..64 { + for j in 0..64 { + let mut b = MutableBitmap::new(); + b.extend_constant(i, false); + b.extend_constant(j, true); + assert_eq!( + b, + MutableBitmap::from_iter( + std::iter::repeat(false) + .take(i) + .chain(std::iter::repeat(true).take(j)) + ) + ); + + let mut b = MutableBitmap::new(); + b.extend_constant(i, true); + b.extend_constant(j, false); + assert_eq!( + b, + MutableBitmap::from_iter( + std::iter::repeat(true) + .take(i) + .chain(std::iter::repeat(false).take(j)) + ) + ); + } + } +} + +#[test] +fn extend_bitmap_one() { + for offset in 0..7 { + let mut b = MutableBitmap::new(); + for _ in 0..4 { + b.extend_from_slice(&[!0], offset, 1); + b.extend_from_slice(&[!0], offset, 1); + } + assert_eq!(b.as_slice(), &[0b11111111]); + } +} + +#[test] +fn extend_bitmap_other() { + let mut a = MutableBitmap::from([true, true, true, false, true, true, true, false, true, true]); + a.extend_from_slice(&[0b01111110u8, 0b10111111, 0b11011111, 0b00000111], 20, 2); + assert_eq!( + a, + MutableBitmap::from([ + true, true, true, false, true, true, true, false, true, true, true, false + ]) + ); +} + +#[test] +fn shrink_to_fit() { + let mut a = MutableBitmap::with_capacity(1025); + a.push(false); + a.shrink_to_fit(); + assert!(a.capacity() < 1025); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs b/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs new file mode 100644 index 0000000000000..104db7fdc3bbb --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/bit_chunks_exact.rs @@ -0,0 +1,33 @@ +use arrow::bitmap::utils::BitChunksExact; + +#[test] +fn basics() { + let mut iter = BitChunksExact::::new(&[0b11111111u8, 0b00000001u8], 9); + assert_eq!(iter.next().unwrap(), 0b11111111u8); + assert_eq!(iter.remainder(), 0b00000001u8); +} + +#[test] +fn basics_u16_small() { + let mut iter = BitChunksExact::::new(&[0b11111111u8], 7); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0000_1111_1111u16); +} + +#[test] +fn basics_u16() { + let mut iter = BitChunksExact::::new(&[0b11111111u8, 0b00000001u8], 9); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0001_1111_1111u16); +} + +#[test] +fn remainder_u16() { + let mut iter = BitChunksExact::::new( + &[0b11111111u8, 0b00000001u8, 0b00000001u8, 0b11011011u8], + 23, + ); + assert_eq!(iter.next(), Some(511)); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 1u16); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs b/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs new file mode 100644 index 0000000000000..d19b6e51b5eda --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/chunk_iter.rs @@ -0,0 +1,163 @@ +use arrow::bitmap::utils::BitChunks; +use arrow::types::BitChunkIter; + +#[test] +fn basics() { + let mut iter = BitChunks::::new(&[0b00000001u8, 0b00000010u8], 0, 16); + assert_eq!(iter.next().unwrap(), 0b0000_0010_0000_0001u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn remainder() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000010u8, 0b00000100u8], 0, 18); + assert_eq!(a.remainder(), 0b00000100u16); +} + +#[test] +fn remainder_saturating() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000010u8, 0b00000010u8], 0, 18); + assert_eq!(a.remainder(), 0b0000_0000_0000_0010u16); +} + +#[test] +fn basics_offset() { + let mut iter = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b00000001u8], 1, 16); + assert_eq!(iter.remainder(), 0); + assert_eq!(iter.next().unwrap(), 0b1000_0001_1000_0000u16); + assert_eq!(iter.next(), None); +} + +#[test] +fn basics_offset_remainder() { + let mut a = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b10000001u8], 1, 15); + assert_eq!(a.next(), None); + assert_eq!(a.remainder(), 0b1000_0001_1000_0000u16); + assert_eq!(a.remainder_len(), 15); +} + +#[test] +fn offset_remainder_saturating() { + let a = BitChunks::::new(&[0b00000001u8, 0b00000011u8, 0b00000011u8], 1, 17); + assert_eq!(a.remainder(), 0b0000_0000_0000_0001u16); +} + +#[test] +fn offset_remainder_saturating2() { + let a = BitChunks::::new(&[0b01001001u8, 0b00000001], 1, 8); + assert_eq!(a.remainder(), 0b1010_0100u64); +} + +#[test] +fn offset_remainder_saturating3() { + let input: &[u8] = &[0b01000000, 0b01000001]; + let a = BitChunks::::new(input, 8, 2); + assert_eq!(a.remainder(), 0b0100_0001u64); +} + +#[test] +fn basics_multiple() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 0, + 4 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0010_0000_0001u16); + assert_eq!(iter.next().unwrap(), 0b0000_1000_0000_0100u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn basics_multiple_offset() { + let mut iter = BitChunks::::new( + &[ + 0b00000001u8, + 0b00000010u8, + 0b00000100u8, + 0b00001000u8, + 0b00000001u8, + ], + 1, + 4 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0001_0000_0000u16); + assert_eq!(iter.next().unwrap(), 0b1000_0100_0000_0010u16); + assert_eq!(iter.remainder(), 0); +} + +#[test] +fn remainder_large() { + let input: &[u8] = &[ + 0b00100100, 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00100100, + 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00000100, + ]; + let mut iter = BitChunks::::new(input, 0, 8 * 12 + 4); + assert_eq!(iter.remainder_len(), 100 - 96); + + for j in 0..12 { + let mut a = BitChunkIter::new(iter.next().unwrap(), 8); + for i in 0..8 { + assert_eq!(a.next().unwrap(), (j * 8 + i + 1) % 3 == 0); + } + } + assert_eq!(None, iter.next()); + + let expected_remainder = 0b00000100u8; + assert_eq!(iter.remainder(), expected_remainder); + + let mut a = BitChunkIter::new(expected_remainder, 8); + for i in 0..4 { + assert_eq!(a.next().unwrap(), (i + 1) % 3 == 0); + } +} + +#[test] +fn basics_1() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 8, + 3 * 8, + ); + assert_eq!(iter.next().unwrap(), 0b0000_0100_0000_0010u16); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b0000_0000_0000_1000u16); + assert_eq!(iter.remainder_len(), 8); +} + +#[test] +fn basics_2() { + let mut iter = BitChunks::::new( + &[0b00000001u8, 0b00000010u8, 0b00000100u8, 0b00001000u8], + 7, + 3 * 8, + ); + assert_eq!(iter.remainder(), 0b0000_0000_0001_0000u16); + assert_eq!(iter.next().unwrap(), 0b0000_1000_0000_0100u16); + assert_eq!(iter.next(), None); +} + +#[test] +fn remainder_1() { + let mut iter = BitChunks::::new(&[0b11111111u8, 0b00000001u8], 0, 9); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b1_1111_1111u64); +} + +#[test] +fn remainder_2() { + // (i % 3 == 0) in bitmap + let input: &[u8] = &[ + 0b01001001, 0b10010010, 0b00100100, 0b01001001, 0b10010010, 0b00100100, 0b01001001, + 0b10010010, 0b00100100, 0b01001001, /* 73 */ + 0b10010010, /* 146 */ + 0b00100100, 0b00001001, + ]; + let offset = 10; // 8 + 2 + let length = 90; + + let mut iter = BitChunks::::new(input, offset, length); + let first: u64 = 0b0100100100100100100100100100100100100100100100100100100100100100; + assert_eq!(first, iter.next().unwrap()); + assert_eq!(iter.next(), None); + assert_eq!(iter.remainder(), 0b10010010010010010010010010u64); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs new file mode 100644 index 0000000000000..b07c50db90111 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/fmt.rs @@ -0,0 +1,40 @@ +use arrow::bitmap::utils::fmt; + +struct A<'a>(&'a [u8], usize, usize); + +impl<'a> std::fmt::Debug for A<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt(self.0, self.1, self.2, f) + } +} + +#[test] +fn test_debug() -> std::fmt::Result { + assert_eq!(format!("{:?}", A(&[1], 0, 0)), "[]"); + assert_eq!(format!("{:?}", A(&[0b11000001], 0, 8)), "[0b11000001]"); + assert_eq!( + format!("{:?}", A(&[0b11000001, 1], 0, 9)), + "[0b11000001, 0b_______1]" + ); + assert_eq!(format!("{:?}", A(&[1], 0, 2)), "[0b______01]"); + assert_eq!(format!("{:?}", A(&[1], 1, 2)), "[0b_____00_]"); + assert_eq!(format!("{:?}", A(&[1], 2, 2)), "[0b____00__]"); + assert_eq!(format!("{:?}", A(&[1], 3, 2)), "[0b___00___]"); + assert_eq!(format!("{:?}", A(&[1], 4, 2)), "[0b__00____]"); + assert_eq!(format!("{:?}", A(&[1], 5, 2)), "[0b_00_____]"); + assert_eq!(format!("{:?}", A(&[1], 6, 2)), "[0b00______]"); + assert_eq!( + format!("{:?}", A(&[0b11000001, 1], 1, 9)), + "[0b1100000_, 0b______01]" + ); + // extra bytes are ignored + assert_eq!( + format!("{:?}", A(&[0b11000001, 1, 1, 1], 1, 9)), + "[0b1100000_, 0b______01]" + ); + assert_eq!( + format!("{:?}", A(&[0b11000001, 1, 1], 2, 16)), + "[0b110000__, 0b00000001, 0b______01]" + ); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs b/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs new file mode 100644 index 0000000000000..184a428f137b5 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/iterator.rs @@ -0,0 +1,44 @@ +use arrow::bitmap::utils::BitmapIter; + +#[test] +fn basic() { + let values = &[0b01011011u8]; + let iter = BitmapIter::new(values, 0, 6); + let result = iter.collect::>(); + assert_eq!(result, vec![true, true, false, true, true, false]) +} + +#[test] +fn large() { + let values = &[0b01011011u8]; + let values = std::iter::repeat(values) + .take(63) + .flatten() + .copied() + .collect::>(); + let len = 63 * 8; + let iter = BitmapIter::new(&values, 0, len); + assert_eq!(iter.count(), len); +} + +#[test] +fn offset() { + let values = &[0b01011011u8]; + let iter = BitmapIter::new(values, 2, 4); + let result = iter.collect::>(); + assert_eq!(result, vec![false, true, true, false]) +} + +#[test] +fn rev() { + let values = &[0b01011011u8, 0b01011011u8]; + let iter = BitmapIter::new(values, 2, 13); + let result = iter.rev().collect::>(); + assert_eq!( + result, + vec![false, true, true, false, true, false, true, true, false, true, true, false, true] + .into_iter() + .rev() + .collect::>() + ) +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/mod.rs b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs new file mode 100644 index 0000000000000..12af43e4e9497 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/mod.rs @@ -0,0 +1,83 @@ +use arrow::bitmap::utils::*; +use proptest::prelude::*; + +use super::bitmap_strategy; + +mod bit_chunks_exact; +mod chunk_iter; +mod fmt; +mod iterator; +mod slice_iterator; +mod zip_validity; + +#[test] +fn get_bit_basics() { + let input: &[u8] = &[ + 0b00000000, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + for i in 0..8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 8)); + for i in 8 + 1..2 * 8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 2 * 8 + 1)); + for i in 2 * 8 + 2..3 * 8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 3 * 8 + 2)); + for i in 3 * 8 + 3..4 * 8 { + assert!(!get_bit(input, i)); + } + assert!(get_bit(input, 4 * 8 + 3)); +} + +#[test] +fn count_zeros_basics() { + let input: &[u8] = &[ + 0b01001001, 0b00000001, 0b00000010, 0b00000100, 0b00001000, 0b00010000, 0b00100000, + 0b01000000, 0b11111111, + ]; + assert_eq!(count_zeros(input, 0, 8), 8 - 3); + assert_eq!(count_zeros(input, 1, 7), 7 - 2); + assert_eq!(count_zeros(input, 1, 8), 8 - 3); + assert_eq!(count_zeros(input, 2, 7), 7 - 3); + assert_eq!(count_zeros(input, 0, 32), 32 - 6); + assert_eq!(count_zeros(input, 9, 2), 2); + + let input: &[u8] = &[0b01000000, 0b01000001]; + assert_eq!(count_zeros(input, 8, 2), 1); + assert_eq!(count_zeros(input, 8, 3), 2); + assert_eq!(count_zeros(input, 8, 4), 3); + assert_eq!(count_zeros(input, 8, 5), 4); + assert_eq!(count_zeros(input, 8, 6), 5); + assert_eq!(count_zeros(input, 8, 7), 5); + assert_eq!(count_zeros(input, 8, 8), 6); + + let input: &[u8] = &[0b01000000, 0b01010101]; + assert_eq!(count_zeros(input, 9, 2), 1); + assert_eq!(count_zeros(input, 10, 2), 1); + assert_eq!(count_zeros(input, 11, 2), 1); + assert_eq!(count_zeros(input, 12, 2), 1); + assert_eq!(count_zeros(input, 13, 2), 1); + assert_eq!(count_zeros(input, 14, 2), 1); +} + +#[test] +fn count_zeros_1() { + // offset = 10, len = 90 => remainder + let input: &[u8] = &[73, 146, 36, 73, 146, 36, 73, 146, 36, 73, 146, 36, 9]; + assert_eq!(count_zeros(input, 10, 90), 60); +} + +proptest! { + /// Asserts that `Bitmap::null_count` equals the number of unset bits + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn null_count(bitmap in bitmap_strategy()) { + let sum_of_sets: usize = (0..bitmap.len()).map(|x| (!bitmap.get_bit(x)) as usize).sum(); + assert_eq!(bitmap.unset_bits(), sum_of_sets); + } +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs b/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs new file mode 100644 index 0000000000000..4a0d024643ecc --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/slice_iterator.rs @@ -0,0 +1,150 @@ +use arrow::bitmap::utils::SlicesIterator; +use arrow::bitmap::Bitmap; +use proptest::prelude::*; + +use super::bitmap_strategy; + +proptest! { + /// Asserts that: + /// * `slots` is the number of set bits in the bitmap + /// * the sum of the lens of the slices equals `slots` + /// * each item on each slice is set + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well :( + fn check_invariants(bitmap in bitmap_strategy()) { + let iter = SlicesIterator::new(&bitmap); + + let slots = iter.slots(); + + assert_eq!(bitmap.len() - bitmap.unset_bits(), slots); + + let slices = iter.collect::>(); + let mut sum = 0; + for (start, len) in slices { + sum += len; + for i in start..(start+len) { + assert!(bitmap.get_bit(i)); + } + } + assert_eq!(sum, slots); + } +} + +#[test] +fn single_set() { + let values = (0..16).map(|i| i == 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn single_unset() { + let values = (0..64).map(|i| i != 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(0, 1), (2, 62)]); + assert_eq!(count, 64 - 1); +} + +#[test] +fn generic() { + let values = (0..130).map(|i| i % 62 != 0).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 61), (63, 61), (125, 5)]); + assert_eq!(count, 61 + 61 + 5); +} + +#[test] +fn incomplete_byte() { + let values = (0..6).map(|i| i == 1).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(1, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn incomplete_byte1() { + let values = (0..12).map(|i| i == 9).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(9, 1)]); + assert_eq!(count, 1); +} + +#[test] +fn end_of_byte() { + let values = (0..16).map(|i| i != 7).collect::(); + + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + let chunks = iter.collect::>(); + + assert_eq!(chunks, vec![(0, 7), (8, 8)]); + assert_eq!(count, 15); +} + +#[test] +fn bla() { + let values = vec![true, true, true, true, true, true, true, false] + .into_iter() + .collect::(); + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + assert_eq!(values.unset_bits() + iter.slots(), values.len()); + + let total = iter.into_iter().fold(0, |acc, x| acc + x.1); + + assert_eq!(count, total); +} + +#[test] +fn past_end_should_not_be_returned() { + let values = Bitmap::from_u8_slice([0b11111010], 3); + let iter = SlicesIterator::new(&values); + let count = iter.slots(); + assert_eq!(values.unset_bits() + iter.slots(), values.len()); + + let total = iter.into_iter().fold(0, |acc, x| acc + x.1); + + assert_eq!(count, total); +} + +#[test] +fn sliced() { + let values = Bitmap::from_u8_slice([0b11111010, 0b11111011], 16); + let values = values.sliced(8, 2); + let iter = SlicesIterator::new(&values); + + let chunks = iter.collect::>(); + + // the first "11" in the second byte + assert_eq!(chunks, vec![(0, 2)]); +} + +#[test] +fn remainder_1() { + let values = Bitmap::from_u8_slice([0, 0, 0b00000000, 0b00010101], 27); + let values = values.sliced(22, 5); + let iter = SlicesIterator::new(&values); + let chunks = iter.collect::>(); + assert_eq!(chunks, vec![(2, 1), (4, 1)]); +} diff --git a/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs b/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs new file mode 100644 index 0000000000000..a12dedaa43d96 --- /dev/null +++ b/crates/polars/tests/it/arrow/bitmap/utils/zip_validity.rs @@ -0,0 +1,106 @@ +use arrow::bitmap::utils::{BitmapIter, ZipValidity}; +use arrow::bitmap::Bitmap; + +#[test] +fn basic() { + let a = Bitmap::from([true, false]); + let a = Some(a.iter()); + let values = vec![0, 1]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some(0), None]); +} + +#[test] +fn complete() { + let a = Bitmap::from([true, false, true, false, true, false, true, false]); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![Some(0), None, Some(2), None, Some(4), None, Some(6), None] + ); +} + +#[test] +fn slices() { + let a = Bitmap::from([true, false]); + let a = Some(a.iter()); + let offsets = [0, 2, 3]; + let values = [1, 2, 3]; + let iter = offsets.windows(2).map(|x| { + let start = x[0]; + let end = x[1]; + &values[start..end] + }); + let zip = ZipValidity::new(iter, a); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some([1, 2].as_ref()), None]); +} + +#[test] +fn byte() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7, 8]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![ + Some(0), + None, + Some(2), + None, + None, + Some(5), + Some(6), + None, + Some(8) + ] + ); +} + +#[test] +fn offset() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]).sliced(1, 8); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let a = zip.collect::>(); + assert_eq!( + a, + vec![None, Some(1), None, None, Some(4), Some(5), None, Some(7)] + ); +} + +#[test] +fn none() { + let values = vec![0, 1, 2]; + let zip = ZipValidity::new(values.into_iter(), None::); + + let a = zip.collect::>(); + assert_eq!(a, vec![Some(0), Some(1), Some(2)]); +} + +#[test] +fn rev() { + let a = Bitmap::from([true, false, true, false, false, true, true, false, true]).sliced(1, 8); + let a = Some(a.iter()); + let values = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let zip = ZipValidity::new(values.into_iter(), a); + + let result = zip.rev().collect::>(); + let expected = vec![None, Some(1), None, None, Some(4), Some(5), None, Some(7)] + .into_iter() + .rev() + .collect::>(); + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/buffer/immutable.rs b/crates/polars/tests/it/arrow/buffer/immutable.rs new file mode 100644 index 0000000000000..aaf16ad8fa877 --- /dev/null +++ b/crates/polars/tests/it/arrow/buffer/immutable.rs @@ -0,0 +1,119 @@ +use arrow::buffer::Buffer; + +#[test] +fn new() { + let buffer = Buffer::::new(); + assert_eq!(buffer.len(), 0); + assert!(buffer.is_empty()); +} + +#[test] +fn from_slice() { + let buffer = Buffer::::from(vec![0, 1, 2]); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +fn slice() { + let buffer = Buffer::::from(vec![0, 1, 2, 3]); + let buffer = buffer.sliced(1, 2); + assert_eq!(buffer.len(), 2); + assert_eq!(buffer.as_slice(), &[1, 2]); +} + +#[test] +fn from_iter() { + let buffer = (0..3).collect::>(); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +fn debug() { + let buffer = Buffer::::from(vec![0, 1, 2, 3]); + let buffer = buffer.sliced(1, 2); + let a = format!("{buffer:?}"); + assert_eq!(a, "[1, 2]") +} + +#[test] +fn from_vec() { + let buffer = Buffer::::from(vec![0, 1, 2]); + assert_eq!(buffer.len(), 3); + assert_eq!(buffer.as_slice(), &[0, 1, 2]); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow() { + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 3); + assert_eq!(b.as_slice(), &[1, 2, 3]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); + + let buffer = buffer.slice(4); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 2); + assert_eq!(b.as_slice(), &[2, 3]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); + + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i64, 2_i64]); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 4); + assert_eq!(b.as_slice(), &[1, 0, 2, 0]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); + + let buffer = buffer.slice(4); + let b = Buffer::::from(buffer.clone()); + assert_eq!(b.len(), 3); + assert_eq!(b.as_slice(), &[0, 2, 0]); + let back = arrow_buffer::Buffer::from(b); + assert_eq!(back, buffer); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow_vec() { + // Zero-copy vec conversion in arrow-rs + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let back: Vec = buffer.into_vec().unwrap(); + + // Zero-copy vec conversion in arrow2 + let buffer = Buffer::::from(back); + let back: Vec = buffer.into_mut().unwrap_right(); + + let buffer = arrow_buffer::Buffer::from_vec(back); + let buffer = Buffer::::from(buffer); + + // But not possible after conversion between buffer representations + let _ = buffer.into_mut().unwrap_left(); + + let buffer = Buffer::::from(vec![1_i32]); + let buffer = arrow_buffer::Buffer::from(buffer); + + // But not possible after conversion between buffer representations + let _ = buffer.into_vec::().unwrap_err(); +} + +#[test] +#[cfg(feature = "arrow")] +#[should_panic(expected = "not aligned")] +fn from_arrow_misaligned() { + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]).slice(1); + let _ = Buffer::::from(buffer); +} + +#[test] +#[cfg(feature = "arrow")] +fn from_arrow_sliced() { + let buffer = arrow_buffer::Buffer::from_vec(vec![1_i32, 2_i32, 3_i32]); + let b = Buffer::::from(buffer); + let sliced = b.sliced(1, 2); + let back = arrow_buffer::Buffer::from(sliced); + assert_eq!(back.typed_data::(), &[2, 3]); +} diff --git a/crates/polars/tests/it/arrow/buffer/mod.rs b/crates/polars/tests/it/arrow/buffer/mod.rs new file mode 100644 index 0000000000000..723312cd1a873 --- /dev/null +++ b/crates/polars/tests/it/arrow/buffer/mod.rs @@ -0,0 +1 @@ +mod immutable; diff --git a/crates/polars/tests/it/arrow/compute/aggregate/memory.rs b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs new file mode 100644 index 0000000000000..3f31240b86022 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/memory.rs @@ -0,0 +1,32 @@ +use arrow::array::*; +use arrow::compute::aggregate::estimated_bytes_size; +use arrow::datatypes::{ArrowDataType, Field}; + +#[test] +fn primitive() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5]); + assert_eq!(5 * std::mem::size_of::(), estimated_bytes_size(&a)); +} + +#[test] +fn boolean() { + let a = BooleanArray::from_slice([true]); + assert_eq!(1, estimated_bytes_size(&a)); +} + +#[test] +fn utf8() { + let a = Utf8Array::::from_slice(["aaa"]); + assert_eq!(3 + 2 * std::mem::size_of::(), estimated_bytes_size(&a)); +} + +#[test] +fn fixed_size_list() { + let data_type = ArrowDataType::FixedSizeList( + Box::new(Field::new("elem", ArrowDataType::Float32, false)), + 3, + ); + let values = Box::new(Float32Array::from_slice([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])); + let a = FixedSizeListArray::new(data_type, values, None); + assert_eq!(6 * std::mem::size_of::(), estimated_bytes_size(&a)); +} diff --git a/crates/polars/tests/it/arrow/compute/aggregate/mod.rs b/crates/polars/tests/it/arrow/compute/aggregate/mod.rs new file mode 100644 index 0000000000000..d7de8a8c37c5e --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/mod.rs @@ -0,0 +1,2 @@ +mod memory; +mod sum; diff --git a/crates/polars/tests/it/arrow/compute/aggregate/sum.rs b/crates/polars/tests/it/arrow/compute/aggregate/sum.rs new file mode 100644 index 0000000000000..011f75aad356f --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/aggregate/sum.rs @@ -0,0 +1,37 @@ +use arrow::array::*; +use arrow::compute::aggregate::{sum, sum_primitive}; +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{PrimitiveScalar, Scalar}; + +#[test] +fn test_primitive_array_sum() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5]); + assert_eq!( + &PrimitiveScalar::::from(Some(15)) as &dyn Scalar, + sum(&a).unwrap().as_ref() + ); + + let a = a.to(ArrowDataType::Date32); + assert_eq!( + &PrimitiveScalar::::from(Some(15)).to(ArrowDataType::Date32) as &dyn Scalar, + sum(&a).unwrap().as_ref() + ); +} + +#[test] +fn test_primitive_array_float_sum() { + let a = Float64Array::from_slice([1.1f64, 2.2, 3.3, 4.4, 5.5]); + assert!((16.5 - sum_primitive(&a).unwrap()).abs() < f64::EPSILON); +} + +#[test] +fn test_primitive_array_sum_with_nulls() { + let a = Int32Array::from(&[None, Some(2), Some(3), None, Some(5)]); + assert_eq!(10, sum_primitive(&a).unwrap()); +} + +#[test] +fn test_primitive_array_sum_all_nulls() { + let a = Int32Array::from(&[None, None, None]); + assert_eq!(None, sum_primitive(&a)); +} diff --git a/crates/polars/tests/it/arrow/compute/arity_assign.rs b/crates/polars/tests/it/arrow/compute/arity_assign.rs new file mode 100644 index 0000000000000..b8ba89dda2383 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/arity_assign.rs @@ -0,0 +1,21 @@ +use arrow::array::Int32Array; +use arrow::compute::arity_assign::{binary, unary}; + +#[test] +fn test_unary_assign() { + let mut a = Int32Array::from([Some(5), Some(6), None, Some(10)]); + + unary(&mut a, |x| x + 10); + + assert_eq!(a, Int32Array::from([Some(15), Some(16), None, Some(20)])) +} + +#[test] +fn test_binary_assign() { + let mut a = Int32Array::from([Some(5), Some(6), None, Some(10)]); + let b = Int32Array::from([Some(1), Some(2), Some(1), None]); + + binary(&mut a, &b, |x, y| x + y); + + assert_eq!(a, Int32Array::from([Some(6), Some(8), None, None])) +} diff --git a/crates/polars/tests/it/arrow/compute/bitwise.rs b/crates/polars/tests/it/arrow/compute/bitwise.rs new file mode 100644 index 0000000000000..e2a380fbd7077 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/bitwise.rs @@ -0,0 +1,41 @@ +use arrow::array::*; +use arrow::compute::bitwise::*; + +#[test] +fn test_xor() { + let a = Int32Array::from(&[Some(2), Some(4), Some(6), Some(7)]); + let b = Int32Array::from(&[None, Some(6), Some(9), Some(7)]); + let result = xor(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(15), Some(0)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_and() { + let a = Int32Array::from(&[Some(1), Some(2), Some(15)]); + let b = Int32Array::from(&[None, Some(2), Some(6)]); + let result = and(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(6)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_or() { + let a = Int32Array::from(&[Some(1), Some(2), Some(0)]); + let b = Int32Array::from(&[None, Some(2), Some(0)]); + let result = or(&a, &b); + let expected = Int32Array::from(&[None, Some(2), Some(0)]); + + assert_eq!(result, expected); +} + +#[test] +fn test_not() { + let a = Int8Array::from(&[None, Some(1i8), Some(-100i8)]); + let result = not(&a); + let expected = Int8Array::from(&[None, Some(-2), Some(99)]); + + assert_eq!(result, expected); +} diff --git a/crates/polars/tests/it/arrow/compute/boolean.rs b/crates/polars/tests/it/arrow/compute/boolean.rs new file mode 100644 index 0000000000000..488a53b4732dd --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/boolean.rs @@ -0,0 +1,453 @@ +use std::iter::FromIterator; + +use arrow::array::*; +use arrow::compute::boolean::*; +use arrow::scalar::BooleanScalar; + +#[test] +fn array_and() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(c, expected); +} + +#[test] +fn array_or() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + let c = or(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, true, true, true]); + + assert_eq!(c, expected); +} + +#[test] +fn array_or_validity() { + let a = BooleanArray::from(vec![ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(vec![ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = or(&a, &b); + + let expected = BooleanArray::from(vec![ + None, + None, + None, + None, + Some(false), + Some(true), + None, + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_not() { + let a = BooleanArray::from_slice(vec![false, true]); + let c = not(&a); + + let expected = BooleanArray::from_slice(vec![true, false]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_validity() { + let a = BooleanArray::from(vec![ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(vec![ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = and(&a, &b); + + let expected = BooleanArray::from(vec![ + None, + None, + None, + None, + Some(false), + Some(false), + None, + Some(false), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_sliced_same_offset() { + let a = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, false, true, true, + ]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let a = a.sliced(8, 4); + let b = b.sliced(8, 4); + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_same_offset_mod8() { + let a = BooleanArray::from_slice(vec![ + false, false, true, true, false, false, false, false, false, false, false, false, + ]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let a = a.sliced(0, 4); + let b = b.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_offset1() { + let a = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, false, true, true, + ]); + let b = BooleanArray::from_slice(vec![false, true, false, true]); + + let a = a.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_sliced_offset2() { + let a = BooleanArray::from_slice(vec![false, false, true, true]); + let b = BooleanArray::from_slice(vec![ + false, false, false, false, false, false, false, false, false, true, false, true, + ]); + + let b = b.sliced(8, 4); + + let c = and(&a, &b); + + let expected = BooleanArray::from_slice(vec![false, false, false, true]); + + assert_eq!(expected, c); +} + +#[test] +fn array_and_validity_offset() { + let a = BooleanArray::from(vec![None, Some(false), Some(true), None, Some(true)]); + let a = a.sliced(1, 4); + let a = a.as_any().downcast_ref::().unwrap(); + + let b = BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(true), + Some(true), + ]); + + let b = b.sliced(2, 4); + let b = b.as_any().downcast_ref::().unwrap(); + + let c = and(a, b); + + let expected = BooleanArray::from(vec![Some(false), Some(false), None, Some(true)]); + + assert_eq!(expected, c); +} + +#[test] +fn test_nonnull_array_is_null() { + let a = Int32Array::from_slice([1, 2, 3, 4]); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, false, false, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_with_offset_is_null() { + let a = Int32Array::from_slice(vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); + let a = a.sliced(8, 4); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, false, false, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_is_not_null() { + let a = Int32Array::from_slice([1, 2, 3, 4]); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, true, true, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nonnull_array_with_offset_is_not_null() { + let a = Int32Array::from_slice([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]); + let a = a.sliced(8, 4); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice([true, true, true, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_is_null() { + let a = Int32Array::from(vec![Some(1), None, Some(3), None]); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, true, false, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_with_offset_is_null() { + let a = Int32Array::from(vec![ + None, + None, + None, + None, + None, + None, + None, + None, + // offset 8, previous None values are skipped by the slice + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + ]); + let a = a.sliced(8, 4); + + let res = is_null(&a); + + let expected = BooleanArray::from_slice(vec![false, true, false, true]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_is_not_null() { + let a = Int32Array::from(vec![Some(1), None, Some(3), None]); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, false, true, false]); + + assert_eq!(expected, res); +} + +#[test] +fn test_nullable_array_with_offset_is_not_null() { + let a = Int32Array::from(vec![ + None, + None, + None, + None, + None, + None, + None, + None, + // offset 8, previous None values are skipped by the slice + Some(1), + None, + Some(2), + None, + Some(3), + Some(4), + None, + None, + ]); + let a = a.sliced(8, 4); + + let res = is_not_null(&a); + + let expected = BooleanArray::from_slice(vec![true, false, true, false]); + + assert_eq!(expected, res); +} + +#[test] +fn array_and_scalar() { + let array = BooleanArray::from_slice([false, false, true, true]); + + let scalar = BooleanScalar::new(Some(true)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, true, true]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(Some(false)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, false, false]); + + assert_eq!(real, expected); +} + +#[test] +fn array_and_scalar_validity() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + + let scalar = BooleanScalar::new(Some(true)); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(false), Some(true)]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(None); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); + + let array = BooleanArray::from_slice([true, false, true]); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); +} + +#[test] +fn array_or_scalar() { + let array = BooleanArray::from_slice([false, false, true, true]); + + let scalar = BooleanScalar::new(Some(true)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([true, true, true, true]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(Some(false)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from_slice([false, false, true, true]); + assert_eq!(real, expected); +} + +#[test] +fn array_or_scalar_validity() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + + let scalar = BooleanScalar::new(Some(true)); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(true), Some(true)]); + assert_eq!(real, expected); + + let scalar = BooleanScalar::new(None); + let real = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); + + let array = BooleanArray::from_slice([true, false, true]); + let real = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None; 3]); + assert_eq!(real, expected); +} + +#[test] +fn test_any_all() { + let array = BooleanArray::from(&[None, Some(false), Some(true)]); + assert!(any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[None, Some(false), Some(false)]); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[None, Some(true), Some(true)]); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from_iter(std::iter::repeat(false).take(10).map(Some)); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from_iter(std::iter::repeat(true).take(10).map(Some)); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from_iter([true, false, true, true].map(Some)); + assert!(any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[Some(true)]); + assert!(any(&array)); + assert!(all(&array)); + let array = BooleanArray::from(&[Some(false)]); + assert!(!any(&array)); + assert!(!all(&array)); + let array = BooleanArray::from(&[]); + assert!(!any(&array)); + assert!(all(&array)); +} diff --git a/crates/polars/tests/it/arrow/compute/boolean_kleene.rs b/crates/polars/tests/it/arrow/compute/boolean_kleene.rs new file mode 100644 index 0000000000000..515490796d38f --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/boolean_kleene.rs @@ -0,0 +1,223 @@ +use arrow::array::BooleanArray; +use arrow::compute::boolean_kleene::*; +use arrow::scalar::BooleanScalar; + +#[test] +fn and_generic() { + let lhs = BooleanArray::from(&[ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let rhs = BooleanArray::from(&[ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = and(&lhs, &rhs); + + let expected = BooleanArray::from(&[ + None, + Some(false), + None, + Some(false), + Some(false), + Some(false), + None, + Some(false), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_generic() { + let a = BooleanArray::from(&[ + None, + None, + None, + Some(false), + Some(false), + Some(false), + Some(true), + Some(true), + Some(true), + ]); + let b = BooleanArray::from(&[ + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + None, + Some(false), + Some(true), + ]); + let c = or(&a, &b); + + let expected = BooleanArray::from(&[ + None, + None, + Some(true), + None, + Some(false), + Some(true), + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_right_nulls() { + let a = BooleanArray::from_slice([false, false, false, true, true, true]); + + let b = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let c = or(&a, &b); + + let expected = BooleanArray::from(&[ + Some(true), + Some(false), + None, + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn or_left_nulls() { + let a = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + Some(false), + None, + ]); + + let b = BooleanArray::from_slice([false, false, false, true, true, true]); + + let c = or(&a, &b); + + let expected = BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(c, expected); +} + +#[test] +fn array_and_true() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(true)); + let result = and_scalar(&array, &scalar); + + // Should be same as argument array if scalar is true. + assert_eq!(result, array); +} + +#[test] +fn array_and_false() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(false)); + let result = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[ + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + ]); + + assert_eq!(result, expected); +} + +#[test] +fn array_and_none() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(None); + let result = and_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[None, Some(false), None, None, Some(false), None]); + + assert_eq!(result, expected); +} + +#[test] +fn array_or_true() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(true)); + let result = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[ + Some(true), + Some(true), + Some(true), + Some(true), + Some(true), + Some(true), + ]); + + assert_eq!(result, expected); +} + +#[test] +fn array_or_false() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(Some(false)); + let result = or_scalar(&array, &scalar); + + // Should be same as argument array if scalar is false. + assert_eq!(result, array); +} + +#[test] +fn array_or_none() { + let array = BooleanArray::from(&[Some(true), Some(false), None, Some(true), Some(false), None]); + + let scalar = BooleanScalar::new(None); + let result = or_scalar(&array, &scalar); + + let expected = BooleanArray::from(&[Some(true), None, None, Some(true), None, None]); + + assert_eq!(result, expected); +} + +#[test] +fn array_empty() { + let array = BooleanArray::from(&[]); + assert_eq!(any(&array), Some(false)); + assert_eq!(all(&array), Some(true)); +} diff --git a/crates/polars/tests/it/arrow/compute/if_then_else.rs b/crates/polars/tests/it/arrow/compute/if_then_else.rs new file mode 100644 index 0000000000000..e203d831c39f2 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/if_then_else.rs @@ -0,0 +1,42 @@ +use arrow::array::*; +use arrow::compute::if_then_else::if_then_else; +use polars_error::PolarsResult; + +#[test] +fn basics() -> PolarsResult<()> { + let lhs = Int32Array::from_slice([1, 2, 3]); + let rhs = Int32Array::from_slice([4, 5, 6]); + let predicate = BooleanArray::from_slice(vec![true, false, true]); + let c = if_then_else(&predicate, &lhs, &rhs)?; + + let expected = Int32Array::from_slice([1, 5, 3]); + + assert_eq!(expected, c.as_ref()); + Ok(()) +} + +#[test] +fn basics_nulls() -> PolarsResult<()> { + let lhs = Int32Array::from(&[Some(1), None, None]); + let rhs = Int32Array::from(&[None, Some(5), Some(6)]); + let predicate = BooleanArray::from_slice(vec![true, false, true]); + let c = if_then_else(&predicate, &lhs, &rhs)?; + + let expected = Int32Array::from(&[Some(1), Some(5), None]); + + assert_eq!(expected, c.as_ref()); + Ok(()) +} + +#[test] +fn basics_nulls_pred() -> PolarsResult<()> { + let lhs = Int32Array::from_slice([1, 2, 3]); + let rhs = Int32Array::from_slice([4, 5, 6]); + let predicate = BooleanArray::from(&[Some(true), None, Some(false)]); + let result = if_then_else(&predicate, &lhs, &rhs)?; + + let expected = Int32Array::from(&[Some(1), None, Some(6)]); + + assert_eq!(expected, result.as_ref()); + Ok(()) +} diff --git a/crates/polars/tests/it/arrow/compute/mod.rs b/crates/polars/tests/it/arrow/compute/mod.rs new file mode 100644 index 0000000000000..95126a4a3a547 --- /dev/null +++ b/crates/polars/tests/it/arrow/compute/mod.rs @@ -0,0 +1,12 @@ +#[cfg(feature = "compute_aggregate")] +mod aggregate; +#[cfg(feature = "compute_bitwise")] +mod bitwise; +#[cfg(feature = "compute_boolean")] +mod boolean; +#[cfg(feature = "compute_boolean_kleene")] +mod boolean_kleene; +#[cfg(feature = "compute_if_then_else")] +mod if_then_else; + +mod arity_assign; diff --git a/crates/polars-arrow/tests/it/ffi/data.rs b/crates/polars/tests/it/arrow/ffi/data.rs similarity index 94% rename from crates/polars-arrow/tests/it/ffi/data.rs rename to crates/polars/tests/it/arrow/ffi/data.rs index 1b5fc86922c04..bb798a1bc4fc0 100644 --- a/crates/polars-arrow/tests/it/ffi/data.rs +++ b/crates/polars/tests/it/arrow/ffi/data.rs @@ -1,6 +1,6 @@ -use polars_arrow::array::*; -use polars_arrow::datatypes::Field; -use polars_arrow::ffi; +use arrow::array::*; +use arrow::datatypes::Field; +use arrow::ffi; use polars_error::PolarsResult; fn _test_round_trip(array: Box, expected: Box) -> PolarsResult<()> { diff --git a/crates/polars/tests/it/arrow/ffi/mod.rs b/crates/polars/tests/it/arrow/ffi/mod.rs new file mode 100644 index 0000000000000..1ca8fa75c4008 --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/mod.rs @@ -0,0 +1,3 @@ +mod data; + +mod stream; diff --git a/crates/polars/tests/it/arrow/ffi/stream.rs b/crates/polars/tests/it/arrow/ffi/stream.rs new file mode 100644 index 0000000000000..f949fdf4c88ec --- /dev/null +++ b/crates/polars/tests/it/arrow/ffi/stream.rs @@ -0,0 +1,44 @@ +use arrow::array::*; +use arrow::datatypes::Field; +use arrow::ffi; +use polars_error::{PolarsError, PolarsResult}; + +fn _test_round_trip(arrays: Vec>) -> PolarsResult<()> { + let field = Field::new("a", arrays[0].data_type().clone(), true); + let iter = Box::new(arrays.clone().into_iter().map(Ok)) as _; + + let mut stream = Box::new(ffi::ArrowArrayStream::empty()); + + *stream = ffi::export_iterator(iter, field.clone()); + + // import + let mut stream = unsafe { ffi::ArrowArrayStreamReader::try_new(stream)? }; + + let mut produced_arrays: Vec> = vec![]; + while let Some(array) = unsafe { stream.next() } { + produced_arrays.push(array?); + } + + assert_eq!(produced_arrays, arrays); + assert_eq!(stream.field(), &field); + Ok(()) +} + +#[test] +fn round_trip() -> PolarsResult<()> { + let array = Int32Array::from(&[Some(2), None, Some(1), None]); + let array: Box = Box::new(array); + + _test_round_trip(vec![array.clone(), array.clone(), array]) +} + +#[test] +fn stream_reader_try_new_invalid_argument_error_on_released_stream() { + let released_stream = Box::new(ffi::ArrowArrayStream::empty()); + let reader = unsafe { ffi::ArrowArrayStreamReader::try_new(released_stream) }; + // poor man's assert_matches: + match reader { + Err(PolarsError::InvalidOperation(_)) => {}, + _ => panic!("ArrowArrayStreamReader::try_new did not return an InvalidArgumentError"), + } +} diff --git a/crates/polars-arrow/tests/it/io/ipc/mod.rs b/crates/polars/tests/it/arrow/io/ipc/mod.rs similarity index 89% rename from crates/polars-arrow/tests/it/io/ipc/mod.rs rename to crates/polars/tests/it/arrow/io/ipc/mod.rs index 202eaf0cdfb2e..c55b346e9702e 100644 --- a/crates/polars-arrow/tests/it/io/ipc/mod.rs +++ b/crates/polars/tests/it/arrow/io/ipc/mod.rs @@ -1,12 +1,12 @@ use std::io::Cursor; use std::sync::Arc; -use polars_arrow::array::*; -use polars_arrow::chunk::Chunk; -use polars_arrow::datatypes::{ArrowSchema, ArrowSchemaRef, Field}; -use polars_arrow::io::ipc::read::{read_file_metadata, FileReader}; -use polars_arrow::io::ipc::write::*; -use polars_arrow::io::ipc::IpcField; +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::{ArrowSchema, ArrowSchemaRef, Field}; +use arrow::io::ipc::read::{read_file_metadata, FileReader}; +use arrow::io::ipc::write::*; +use arrow::io::ipc::IpcField; use polars_error::*; pub(crate) fn write( diff --git a/crates/polars-arrow/tests/it/io/mod.rs b/crates/polars/tests/it/arrow/io/mod.rs similarity index 100% rename from crates/polars-arrow/tests/it/io/mod.rs rename to crates/polars/tests/it/arrow/io/mod.rs diff --git a/crates/polars/tests/it/arrow/mod.rs b/crates/polars/tests/it/arrow/mod.rs new file mode 100644 index 0000000000000..f9f3ef3d2ac99 --- /dev/null +++ b/crates/polars/tests/it/arrow/mod.rs @@ -0,0 +1,12 @@ +mod ffi; +#[cfg(feature = "io_ipc_compression")] +mod io; + +mod scalar; + +mod array; +mod bitmap; + +mod buffer; + +mod compute; diff --git a/crates/polars/tests/it/arrow/scalar/binary.rs b/crates/polars/tests/it/arrow/scalar/binary.rs new file mode 100644 index 0000000000000..d1b3e984d379c --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/binary.rs @@ -0,0 +1,31 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{BinaryScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = BinaryScalar::::from(Some("a")); + let b = BinaryScalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BinaryScalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = BinaryScalar::::from(Some("a")); + + assert_eq!(a.value(), Some(b"a".as_ref())); + assert_eq!(a.data_type(), &ArrowDataType::Binary); + assert!(a.is_valid()); + + let a = BinaryScalar::::from(None::<&str>); + + assert_eq!(a.data_type(), &ArrowDataType::LargeBinary); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/boolean.rs b/crates/polars/tests/it/arrow/scalar/boolean.rs new file mode 100644 index 0000000000000..7c400b0fde3ef --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/boolean.rs @@ -0,0 +1,26 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{BooleanScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = BooleanScalar::from(Some(true)); + let b = BooleanScalar::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = BooleanScalar::from(Some(false)); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = BooleanScalar::new(Some(true)); + + assert_eq!(a.value(), Some(true)); + assert_eq!(a.data_type(), &ArrowDataType::Boolean); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs new file mode 100644 index 0000000000000..c83bc4d697497 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_binary.rs @@ -0,0 +1,26 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{FixedSizeBinaryScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); + let b = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = FixedSizeBinaryScalar::new(ArrowDataType::FixedSizeBinary(1), Some("a")); + + assert_eq!(a.value(), Some(b"a".as_ref())); + assert_eq!(a.data_type(), &ArrowDataType::FixedSizeBinary(1)); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs new file mode 100644 index 0000000000000..2aa6f45bbd744 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/fixed_size_list.rs @@ -0,0 +1,43 @@ +use arrow::array::BooleanArray; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{FixedSizeListScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Boolean, true)), 2); + let a = FixedSizeListScalar::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + let b = FixedSizeListScalar::new(dt.clone(), None); + + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + + let b = FixedSizeListScalar::new(dt, Some(BooleanArray::from_slice([true, true]).boxed())); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = + ArrowDataType::FixedSizeList(Box::new(Field::new("a", ArrowDataType::Boolean, true)), 2); + let a = FixedSizeListScalar::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + assert_eq!( + BooleanArray::from_slice([true, false]), + a.values().unwrap().as_ref() + ); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/list.rs b/crates/polars/tests/it/arrow/scalar/list.rs new file mode 100644 index 0000000000000..7cd2938237c91 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/list.rs @@ -0,0 +1,35 @@ +use arrow::array::BooleanArray; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{ListScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Boolean, true))); + let a = ListScalar::::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + let b = ListScalar::::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = ListScalar::::new(dt, Some(BooleanArray::from_slice([true, true]).boxed())); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::List(Box::new(Field::new("a", ArrowDataType::Boolean, true))); + let a = ListScalar::::new( + dt.clone(), + Some(BooleanArray::from_slice([true, false]).boxed()), + ); + + assert_eq!(BooleanArray::from_slice([true, false]), a.values().as_ref()); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/map.rs b/crates/polars/tests/it/arrow/scalar/map.rs new file mode 100644 index 0000000000000..e9f0ede0784f9 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/map.rs @@ -0,0 +1,66 @@ +use arrow::array::{BooleanArray, StructArray, Utf8Array}; +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{MapScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let kv_dt = ArrowDataType::Struct(vec![ + Field::new("key", ArrowDataType::Utf8, false), + Field::new("value", ArrowDataType::Boolean, true), + ]); + let kv_array1 = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + let kv_array2 = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k3")]).boxed(), + BooleanArray::from_slice([true, true]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = ArrowDataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array1))); + let b = MapScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = MapScalar::new(dt, Some(Box::new(kv_array2))); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let kv_dt = ArrowDataType::Struct(vec![ + Field::new("key", ArrowDataType::Utf8, false), + Field::new("value", ArrowDataType::Boolean, true), + ]); + let kv_array = StructArray::try_new( + kv_dt.clone(), + vec![ + Utf8Array::::from([Some("k1"), Some("k2")]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + ], + None, + ) + .unwrap(); + + let dt = ArrowDataType::Map(Box::new(Field::new("entries", kv_dt, true)), false); + let a = MapScalar::new(dt.clone(), Some(Box::new(kv_array.clone()))); + + assert_eq!(kv_array, a.values().as_ref()); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/mod.rs b/crates/polars/tests/it/arrow/scalar/mod.rs new file mode 100644 index 0000000000000..0c1ef990b8295 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/mod.rs @@ -0,0 +1,16 @@ +mod binary; +mod boolean; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod utf8; + +// check that `PartialEq` can be derived +#[derive(PartialEq)] +struct A { + array: Box, +} diff --git a/crates/polars/tests/it/arrow/scalar/null.rs b/crates/polars/tests/it/arrow/scalar/null.rs new file mode 100644 index 0000000000000..3ceaf69f83b60 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/null.rs @@ -0,0 +1,19 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{NullScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = NullScalar::new(); + assert_eq!(a, a); +} + +#[test] +fn basics() { + let a = NullScalar::default(); + + assert_eq!(a.data_type(), &ArrowDataType::Null); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/primitive.rs b/crates/polars/tests/it/arrow/scalar/primitive.rs new file mode 100644 index 0000000000000..954a80147833c --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/primitive.rs @@ -0,0 +1,36 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{PrimitiveScalar, Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = PrimitiveScalar::from(Some(2i32)); + let b = PrimitiveScalar::::from(None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = PrimitiveScalar::::from(Some(1i32)); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = PrimitiveScalar::from(Some(2i32)); + + assert_eq!(a.value(), &Some(2i32)); + assert_eq!(a.data_type(), &ArrowDataType::Int32); + + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.data_type(), &ArrowDataType::Date32); + + let a = PrimitiveScalar::::from(None); + + assert_eq!(a.data_type(), &ArrowDataType::Int32); + assert!(!a.is_valid()); + + let a = a.to(ArrowDataType::Date32); + assert_eq!(a.data_type(), &ArrowDataType::Date32); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/struct_.rs b/crates/polars/tests/it/arrow/scalar/struct_.rs new file mode 100644 index 0000000000000..23461bb26568a --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/struct_.rs @@ -0,0 +1,41 @@ +use arrow::datatypes::{ArrowDataType, Field}; +use arrow::scalar::{BooleanScalar, Scalar, StructScalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let dt = ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Boolean, true)]); + let a = StructScalar::new( + dt.clone(), + Some(vec![ + Box::new(BooleanScalar::from(Some(true))) as Box + ]), + ); + let b = StructScalar::new(dt.clone(), None); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = StructScalar::new( + dt, + Some(vec![ + Box::new(BooleanScalar::from(Some(false))) as Box + ]), + ); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let dt = ArrowDataType::Struct(vec![Field::new("a", ArrowDataType::Boolean, true)]); + + let values = vec![Box::new(BooleanScalar::from(Some(true))) as Box]; + + let a = StructScalar::new(dt.clone(), Some(values.clone())); + + assert_eq!(a.values(), &values); + assert_eq!(a.data_type(), &dt); + assert!(a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/arrow/scalar/utf8.rs b/crates/polars/tests/it/arrow/scalar/utf8.rs new file mode 100644 index 0000000000000..bd7c6449d89c6 --- /dev/null +++ b/crates/polars/tests/it/arrow/scalar/utf8.rs @@ -0,0 +1,31 @@ +use arrow::datatypes::ArrowDataType; +use arrow::scalar::{Scalar, Utf8Scalar}; + +#[allow(clippy::eq_op)] +#[test] +fn equal() { + let a = Utf8Scalar::::from(Some("a")); + let b = Utf8Scalar::::from(None::<&str>); + assert_eq!(a, a); + assert_eq!(b, b); + assert!(a != b); + let b = Utf8Scalar::::from(Some("b")); + assert!(a != b); + assert_eq!(b, b); +} + +#[test] +fn basics() { + let a = Utf8Scalar::::from(Some("a")); + + assert_eq!(a.value(), Some("a")); + assert_eq!(a.data_type(), &ArrowDataType::Utf8); + assert!(a.is_valid()); + + let a = Utf8Scalar::::from(None::<&str>); + + assert_eq!(a.data_type(), &ArrowDataType::LargeUtf8); + assert!(!a.is_valid()); + + let _: &dyn std::any::Any = a.as_any(); +} diff --git a/crates/polars/tests/it/chunks/mod.rs b/crates/polars/tests/it/chunks/mod.rs new file mode 100644 index 0000000000000..ab7fe5c8ec35b --- /dev/null +++ b/crates/polars/tests/it/chunks/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "parquet")] +mod parquet; diff --git a/crates/polars/tests/it/chunks/parquet.rs b/crates/polars/tests/it/chunks/parquet.rs new file mode 100644 index 0000000000000..26c37566845ab --- /dev/null +++ b/crates/polars/tests/it/chunks/parquet.rs @@ -0,0 +1,38 @@ +use std::io::{Seek, SeekFrom}; + +use polars::prelude::*; + +#[test] +fn test_cast_join_14872() { + let df1 = df![ + "ints" => [1] + ] + .unwrap(); + + let mut df2 = df![ + "ints" => [0, 1], + "strings" => vec![Series::new("", ["a"]); 2], + ] + .unwrap(); + + let mut buf = std::io::Cursor::new(vec![]); + ParquetWriter::new(&mut buf) + .with_row_group_size(Some(1)) + .finish(&mut df2) + .unwrap(); + + let _ = buf.seek(SeekFrom::Start(0)); + let df2 = ParquetReader::new(buf).finish().unwrap(); + + let out = df1 + .join(&df2, ["ints"], ["ints"], JoinArgs::new(JoinType::Left)) + .unwrap(); + + let expected = df![ + "ints" => [1], + "strings" => vec![Series::new("", ["a"]); 1], + ] + .unwrap(); + + assert!(expected.equals(&out)); +} diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 0bd00587fbe2c..212de7960562d 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -1,6 +1,6 @@ use polars_core::utils::{accumulate_dataframes_vertical, split_df}; #[cfg(feature = "dtype-categorical")] -use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; +use polars_core::{disable_string_cache, SINGLE_LOCK}; use super::*; diff --git a/crates/polars/tests/it/core/pivot.rs b/crates/polars/tests/it/core/pivot.rs index ce1bee1785575..6f9c996b44cc3 100644 --- a/crates/polars/tests/it/core/pivot.rs +++ b/crates/polars/tests/it/core/pivot.rs @@ -6,29 +6,47 @@ use polars_ops::pivot::{pivot, pivot_stable, PivotAgg}; #[cfg(feature = "dtype-date")] fn test_pivot_date_() -> PolarsResult<()> { let mut df = df![ - "A" => [1, 1, 1, 1, 1, 1, 1, 1], - "B" => [8, 2, 3, 6, 3, 6, 2, 2], - "C" => [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000] + "index" => [8, 2, 3, 6, 3, 6, 2, 2], + "values1" => [1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000], + "values2" => [1, 1, 1, 1, 1, 1, 1, 1], ]?; - df.try_apply("C", |s| s.cast(&DataType::Date))?; + df.try_apply("values1", |s| s.cast(&DataType::Date))?; - let out = pivot(&df, ["A"], ["B"], ["C"], true, Some(PivotAgg::Count), None)?; + // Test with date as the `columns` input + let out = pivot( + &df, + ["index"], + ["values1"], + Some(["values2"]), + true, + Some(PivotAgg::Count), + None, + )?; let first = 1 as IdxSize; let expected = df![ - "B" => [8i32, 2, 3, 6], + "index" => [8i32, 2, 3, 6], "1972-09-27" => [first, 3, 2, 2] ]?; assert!(out.equals_missing(&expected)); - let mut out = pivot_stable(&df, ["C"], ["B"], ["A"], true, Some(PivotAgg::First), None)?; + // Test with date as the `values` input. + let mut out = pivot_stable( + &df, + ["index"], + ["values2"], + Some(["values1"]), + true, + Some(PivotAgg::First), + None, + )?; out.try_apply("1", |s| { let ca = s.date()?; Ok(ca.to_string("%Y-%d-%m")) })?; let expected = df![ - "B" => [8i32, 2, 3, 6], + "index" => [8i32, 2, 3, 6], "1" => ["1972-27-09", "1972-27-09", "1972-27-09", "1972-27-09"] ]?; assert!(out.equals_missing(&expected)); @@ -38,31 +56,31 @@ fn test_pivot_date_() -> PolarsResult<()> { #[test] fn test_pivot_old() { - let s0 = Series::new("foo", ["A", "A", "B", "B", "C"].as_ref()); - let s1 = Series::new("N", [1, 2, 2, 4, 2].as_ref()); - let s2 = Series::new("bar", ["k", "l", "m", "m", "l"].as_ref()); + let s0 = Series::new("index", ["A", "A", "B", "B", "C"].as_ref()); + let s2 = Series::new("columns", ["k", "l", "m", "m", "l"].as_ref()); + let s1 = Series::new("values", [1, 2, 2, 4, 2].as_ref()); let df = DataFrame::new(vec![s0, s1, s2]).unwrap(); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Sum), None, ) .unwrap(); - assert_eq!(pvt.get_column_names(), &["foo", "k", "l", "m"]); + assert_eq!(pvt.get_column_names(), &["index", "k", "l", "m"]); assert_eq!( Vec::from(&pvt.column("m").unwrap().i32().unwrap().sort(false)), &[None, None, Some(6)] ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Min), None, @@ -74,9 +92,9 @@ fn test_pivot_old() { ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Max), None, @@ -88,9 +106,9 @@ fn test_pivot_old() { ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Mean), None, @@ -102,9 +120,9 @@ fn test_pivot_old() { ); let pvt = pivot( &df, - ["N"], - ["foo"], - ["bar"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Count), None, @@ -120,46 +138,51 @@ fn test_pivot_old() { #[cfg(feature = "dtype-categorical")] fn test_pivot_categorical() -> PolarsResult<()> { let mut df = df![ - "A" => [1, 1, 1, 1, 1, 1, 1, 1], - "B" => [8, 2, 3, 6, 3, 6, 2, 2], - "C" => ["a", "b", "c", "a", "b", "c", "a", "b"] + "index" => [1, 1, 1, 1, 1, 1, 1, 1], + "columns" => ["a", "b", "c", "a", "b", "c", "a", "b"], + "values" => [8, 2, 3, 6, 3, 6, 2, 2], ]?; - df.try_apply("C", |s| { + df.try_apply("columns", |s| { s.cast(&DataType::Categorical(None, Default::default())) })?; - let out = pivot(&df, ["A"], ["B"], ["C"], true, Some(PivotAgg::Count), None)?; - assert_eq!(out.get_column_names(), &["B", "a", "b", "c"]); + let out = pivot( + &df, + ["index"], + ["columns"], + Some(["values"]), + true, + Some(PivotAgg::Count), + None, + )?; + assert_eq!(out.get_column_names(), &["index", "a", "b", "c"]); Ok(()) } #[test] fn test_pivot_new() -> PolarsResult<()> { - let df = df!["A"=> ["foo", "foo", "foo", "foo", "foo", - "bar", "bar", "bar", "bar"], - "B"=> ["one", "one", "one", "two", "two", - "one", "one", "two", "two"], - "C"=> ["small", "large", "large", "small", - "small", "large", "small", "small", "large"], - "breaky"=> ["jam", "egg", "egg", "egg", - "jam", "jam", "potato", "jam", "jam"], - "D"=> [1, 2, 2, 3, 3, 4, 5, 6, 7], - "E"=> [2, 4, 5, 5, 6, 6, 8, 9, 9] + let df = df![ + "index1"=> ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "index2"=> ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "cols1"=> ["small", "large", "large", "small", "small", "large", "small", "small", "large"], + "cols2"=> ["jam", "egg", "egg", "egg", "jam", "jam", "potato", "jam", "jam"], + "values1"=> [1, 2, 2, 3, 3, 4, 5, 6, 7], + "values2"=> [2, 4, 5, 5, 6, 6, 8, 9, 9] ]?; let out = (pivot_stable( &df, - ["D"], - ["A", "B"], - ["C"], + ["index1", "index2"], + ["cols1"], + Some(["values1"]), true, Some(PivotAgg::Sum), None, ))?; let expected = df![ - "A" => ["foo", "foo", "bar", "bar"], - "B" => ["one", "two", "one", "two"], + "index1" => ["foo", "foo", "bar", "bar"], + "index2" => ["one", "two", "one", "two"], "large" => [Some(4), None, Some(4), Some(7)], "small" => [1, 6, 5, 6], ]?; @@ -167,16 +190,16 @@ fn test_pivot_new() -> PolarsResult<()> { let out = pivot_stable( &df, - ["D"], - ["A", "B"], - ["C", "breaky"], + ["index1", "index2"], + ["cols1", "cols2"], + Some(["values1"]), true, Some(PivotAgg::Sum), None, )?; let expected = df![ - "A" => ["foo", "foo", "bar", "bar"], - "B" => ["one", "two", "one", "two"], + "index1" => ["foo", "foo", "bar", "bar"], + "index2" => ["one", "two", "one", "two"], "{\"large\",\"egg\"}" => [Some(4), None, None, None], "{\"large\",\"jam\"}" => [None, None, Some(4), Some(7)], "{\"small\",\"egg\"}" => [None, Some(3), None, None], @@ -191,22 +214,22 @@ fn test_pivot_new() -> PolarsResult<()> { #[test] fn test_pivot_2() -> PolarsResult<()> { let df = df![ - "name"=> ["avg", "avg", "act", "test", "test"], - "err" => [Some("name1"), Some("name2"), None, Some("name1"), Some("name2")], - "wght"=> [0.0, 0.1, 1.0, 0.4, 0.2] + "index" => [Some("name1"), Some("name2"), None, Some("name1"), Some("name2")], + "columns"=> ["avg", "avg", "act", "test", "test"], + "values"=> [0.0, 0.1, 1.0, 0.4, 0.2] ]?; let out = pivot_stable( &df, - ["wght"], - ["err"], - ["name"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::First), None, )?; let expected = df![ - "err" => [Some("name1"), Some("name2"), None], + "index" => [Some("name1"), Some("name2"), None], "avg" => [Some(0.0), Some(0.1), None], "act" => [None, None, Some(1.)], "test" => [Some(0.4), Some(0.2), None], @@ -224,22 +247,22 @@ fn test_pivot_datetime() -> PolarsResult<()> { .and_hms_opt(12, 15, 0) .unwrap(); let df = df![ - "dt" => [dt, dt, dt, dt], - "key" => ["x", "x", "y", "y"], - "val" => [100, 50, 500, -80] + "index" => [dt, dt, dt, dt], + "columns" => ["x", "x", "y", "y"], + "values" => [100, 50, 500, -80] ]?; let out = pivot( &df, - ["val"], - ["dt"], - ["key"], + ["index"], + ["columns"], + Some(["values"]), false, Some(PivotAgg::Sum), None, )?; let expected = df![ - "dt" => [dt], + "index" => [dt], "x" => [150], "y" => [420] ]?; diff --git a/crates/polars/tests/it/io/avro/mod.rs b/crates/polars/tests/it/io/avro/mod.rs new file mode 100644 index 0000000000000..e341bce42737d --- /dev/null +++ b/crates/polars/tests/it/io/avro/mod.rs @@ -0,0 +1,8 @@ +//! Read and write from and to Apache Avro + +mod read; +#[cfg(feature = "avro")] +mod read_async; +mod write; +#[cfg(feature = "avro")] +mod write_async; diff --git a/crates/polars/tests/it/io/avro/read.rs b/crates/polars/tests/it/io/avro/read.rs new file mode 100644 index 0000000000000..d15b805b19b2f --- /dev/null +++ b/crates/polars/tests/it/io/avro/read.rs @@ -0,0 +1,356 @@ +use apache_avro::types::{Record, Value}; +use apache_avro::{Codec, Days, Duration, Millis, Months, Schema as AvroSchema, Writer}; +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::io::avro::avro_schema::read::read_metadata; +use arrow::io::avro::read; +use polars_error::PolarsResult; + +pub(super) fn schema() -> (AvroSchema, ArrowSchema) { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "a", "type": "long"}, + {"name": "b", "type": "string"}, + {"name": "c", "type": "int"}, + { + "name": "date", + "type": "int", + "logicalType": "date" + }, + {"name": "d", "type": "bytes"}, + {"name": "e", "type": "double"}, + {"name": "f", "type": "boolean"}, + {"name": "g", "type": ["null", "string"], "default": null}, + {"name": "h", "type": { + "type": "array", + "items": { + "name": "item", + "type": ["null", "int"], + "default": null + } + }}, + {"name": "i", "type": { + "type": "record", + "name": "bla", + "fields": [ + {"name": "e", "type": "double"} + ] + }}, + {"name": "nullable_struct", "type": [ + "null", { + "type": "record", + "name": "foo", + "fields": [ + {"name": "e", "type": "double"} + ] + }] + , "default": null + } + ] + } +"#; + + let schema = ArrowSchema::from(vec![ + Field::new("a", ArrowDataType::Int64, false), + Field::new("b", ArrowDataType::Utf8, false), + Field::new("c", ArrowDataType::Int32, false), + Field::new("date", ArrowDataType::Date32, false), + Field::new("d", ArrowDataType::Binary, false), + Field::new("e", ArrowDataType::Float64, false), + Field::new("f", ArrowDataType::Boolean, false), + Field::new("g", ArrowDataType::Utf8, true), + Field::new( + "h", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + false, + ), + Field::new( + "i", + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + false, + ), + Field::new( + "nullable_struct", + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + true, + ), + ]); + + (AvroSchema::parse_str(raw_schema).unwrap(), schema) +} + +pub(super) fn data() -> Chunk> { + let data = vec![ + Some(vec![Some(1i32), None, Some(3)]), + Some(vec![Some(1i32), None, Some(3)]), + ]; + + let mut array = MutableListArray::>::new(); + array.try_extend(data).unwrap(); + + let columns = vec![ + Int64Array::from_slice([27, 47]).boxed(), + Utf8Array::::from_slice(["foo", "bar"]).boxed(), + Int32Array::from_slice([1, 1]).boxed(), + Int32Array::from_slice([1, 2]) + .to(ArrowDataType::Date32) + .boxed(), + BinaryArray::::from_slice([b"foo", b"bar"]).boxed(), + PrimitiveArray::::from_slice([1.0, 2.0]).boxed(), + BooleanArray::from_slice([true, false]).boxed(), + Utf8Array::::from([Some("foo"), None]).boxed(), + array.into_box(), + StructArray::new( + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + vec![PrimitiveArray::::from_slice([1.0, 2.0]).boxed()], + None, + ) + .boxed(), + StructArray::new( + ArrowDataType::Struct(vec![Field::new("e", ArrowDataType::Float64, false)]), + vec![PrimitiveArray::::from_slice([1.0, 0.0]).boxed()], + Some([true, false].into()), + ) + .boxed(), + ]; + + Chunk::try_new(columns).unwrap() +} + +pub(super) fn write_avro(codec: Codec) -> Result, apache_avro::Error> { + let (avro, _) = schema(); + // a writer needs a schema and something to write to + let mut writer = Writer::with_codec(&avro, Vec::new(), codec); + + // the Record type models our Record schema + let mut record = Record::new(writer.schema()).unwrap(); + record.put("a", 27i64); + record.put("b", "foo"); + record.put("c", 1i32); + record.put("date", 1i32); + record.put("d", b"foo".as_ref()); + record.put("e", 1.0f64); + record.put("f", true); + record.put("g", Some("foo")); + record.put( + "h", + Value::Array(vec![ + Value::Union(1, Box::new(Value::Int(1))), + Value::Union(0, Box::new(Value::Null)), + Value::Union(1, Box::new(Value::Int(3))), + ]), + ); + record.put( + "i", + Value::Record(vec![("e".to_string(), Value::Double(1.0f64))]), + ); + record.put( + "duration", + Value::Duration(Duration::new(Months::new(1), Days::new(1), Millis::new(1))), + ); + record.put( + "nullable_struct", + Value::Union( + 1, + Box::new(Value::Record(vec![( + "e".to_string(), + Value::Double(1.0f64), + )])), + ), + ); + writer.append(record)?; + + let mut record = Record::new(writer.schema()).unwrap(); + record.put("b", "bar"); + record.put("a", 47i64); + record.put("c", 1i32); + record.put("date", 2i32); + record.put("d", b"bar".as_ref()); + record.put("e", 2.0f64); + record.put("f", false); + record.put("g", None::<&str>); + record.put( + "i", + Value::Record(vec![("e".to_string(), Value::Double(2.0f64))]), + ); + record.put( + "h", + Value::Array(vec![ + Value::Union(1, Box::new(Value::Int(1))), + Value::Union(0, Box::new(Value::Null)), + Value::Union(1, Box::new(Value::Int(3))), + ]), + ); + record.put("nullable_struct", Value::Union(0, Box::new(Value::Null))); + writer.append(record)?; + writer.into_inner() +} + +pub(super) fn read_avro( + mut avro: &[u8], + projection: Option>, +) -> PolarsResult<(Chunk>, ArrowSchema)> { + let file = &mut avro; + + let metadata = read_metadata(file)?; + let schema = read::infer_schema(&metadata.record)?; + + let mut reader = read::Reader::new(file, metadata, schema.fields.clone(), projection.clone()); + + let schema = if let Some(projection) = projection { + let fields = schema + .fields + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect::>(); + ArrowSchema::from(fields) + } else { + schema + }; + + reader.next().unwrap().map(|x| (x, schema)) +} + +fn test(codec: Codec) -> PolarsResult<()> { + let avro = write_avro(codec).unwrap(); + let expected = data(); + let (_, expected_schema) = schema(); + + let (result, schema) = read_avro(&avro, None)?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + Ok(()) +} + +#[test] +fn read_without_codec() -> PolarsResult<()> { + test(Codec::Null) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn read_deflate() -> PolarsResult<()> { + test(Codec::Deflate) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn read_snappy() -> PolarsResult<()> { + test(Codec::Snappy) +} + +#[test] +fn test_projected() -> PolarsResult<()> { + let expected = data(); + let (_, expected_schema) = schema(); + + let avro = write_avro(Codec::Null).unwrap(); + + for i in 0..expected_schema.fields.len() { + let mut projection = vec![false; expected_schema.fields.len()]; + projection[i] = true; + + let expected = expected + .clone() + .into_arrays() + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect(); + let expected = Chunk::new(expected); + + let expected_fields = expected_schema + .clone() + .fields + .into_iter() + .zip(projection.iter()) + .filter_map(|x| if *x.1 { Some(x.0) } else { None }) + .collect::>(); + let expected_schema = ArrowSchema::from(expected_fields); + + let (result, schema) = read_avro(&avro, Some(projection))?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + } + Ok(()) +} + +fn schema_list() -> (AvroSchema, ArrowSchema) { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "h", "type": { + "type": "array", + "items": { + "name": "item", + "type": "int" + } + }} + ] + } +"#; + + let schema = ArrowSchema::from(vec![Field::new( + "h", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, false))), + false, + )]); + + (AvroSchema::parse_str(raw_schema).unwrap(), schema) +} + +pub(super) fn data_list() -> Chunk> { + let data = [Some(vec![Some(1i32), Some(2), Some(3)]), Some(vec![])]; + + let mut array = MutableListArray::>::new_from( + Default::default(), + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, false))), + 0, + ); + array.try_extend(data).unwrap(); + + let columns = vec![array.into_box()]; + + Chunk::try_new(columns).unwrap() +} + +pub(super) fn write_list(codec: Codec) -> Result, apache_avro::Error> { + let (avro, _) = schema_list(); + // a writer needs a schema and something to write to + let mut writer = Writer::with_codec(&avro, Vec::new(), codec); + + // the Record type models our Record schema + let mut record = Record::new(writer.schema()).unwrap(); + record.put( + "h", + Value::Array(vec![Value::Int(1), Value::Int(2), Value::Int(3)]), + ); + writer.append(record)?; + + let mut record = Record::new(writer.schema()).unwrap(); + record.put("h", Value::Array(vec![])); + writer.append(record)?; + Ok(writer.into_inner().unwrap()) +} + +#[test] +fn test_list() -> PolarsResult<()> { + let avro = write_list(Codec::Null).unwrap(); + let expected = data_list(); + let (_, expected_schema) = schema_list(); + + let (result, schema) = read_avro(&avro, None)?; + + assert_eq!(schema, expected_schema); + assert_eq!(result, expected); + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/read_async.rs b/crates/polars/tests/it/io/avro/read_async.rs new file mode 100644 index 0000000000000..d50fd7595c58c --- /dev/null +++ b/crates/polars/tests/it/io/avro/read_async.rs @@ -0,0 +1,42 @@ +use apache_avro::Codec; +use arrow::io::avro::avro_schema::read_async::{block_stream, read_metadata}; +use arrow::io::avro::read; +use futures::{pin_mut, StreamExt}; +use polars_error::PolarsResult; + +use super::read::{schema, write_avro}; + +async fn test(codec: Codec) -> PolarsResult<()> { + let avro_data = write_avro(codec).unwrap(); + let (_, expected_schema) = schema(); + + let mut reader = &mut &avro_data[..]; + + let metadata = read_metadata(&mut reader).await?; + let schema = read::infer_schema(&metadata.record)?; + + assert_eq!(schema, expected_schema); + + let blocks = block_stream(&mut reader, metadata.marker).await; + + pin_mut!(blocks); + while let Some(block) = blocks.next().await.transpose()? { + assert!(block.number_of_rows > 0 || block.data.is_empty()) + } + Ok(()) +} + +#[tokio::test] +async fn read_without_codec() -> PolarsResult<()> { + test(Codec::Null).await +} + +#[tokio::test] +async fn read_deflate() -> PolarsResult<()> { + test(Codec::Deflate).await +} + +#[tokio::test] +async fn read_snappy() -> PolarsResult<()> { + test(Codec::Snappy).await +} diff --git a/crates/polars/tests/it/io/avro/write.rs b/crates/polars/tests/it/io/avro/write.rs new file mode 100644 index 0000000000000..dd8058aa52dcc --- /dev/null +++ b/crates/polars/tests/it/io/avro/write.rs @@ -0,0 +1,372 @@ +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::io::avro::avro_schema::file::{Block, CompressedBlock, Compression}; +use arrow::io::avro::avro_schema::write::{compress, write_block, write_metadata}; +use arrow::io::avro::write; +use avro_schema::schema::{Field as AvroField, Record, Schema as AvroSchema}; +use polars_error::PolarsResult; + +use super::read::read_avro; + +pub(super) fn schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new("int64", ArrowDataType::Int64, false), + Field::new("int64 nullable", ArrowDataType::Int64, true), + Field::new("utf8", ArrowDataType::Utf8, false), + Field::new("utf8 nullable", ArrowDataType::Utf8, true), + Field::new("int32", ArrowDataType::Int32, false), + Field::new("int32 nullable", ArrowDataType::Int32, true), + Field::new("date", ArrowDataType::Date32, false), + Field::new("date nullable", ArrowDataType::Date32, true), + Field::new("binary", ArrowDataType::Binary, false), + Field::new("binary nullable", ArrowDataType::Binary, true), + Field::new("float32", ArrowDataType::Float32, false), + Field::new("float32 nullable", ArrowDataType::Float32, true), + Field::new("float64", ArrowDataType::Float64, false), + Field::new("float64 nullable", ArrowDataType::Float64, true), + Field::new("boolean", ArrowDataType::Boolean, false), + Field::new("boolean nullable", ArrowDataType::Boolean, true), + Field::new( + "list", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + false, + ), + Field::new( + "list nullable", + ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))), + true, + ), + ]) +} + +pub(super) fn data() -> Chunk> { + let list_dt = ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))); + let list_dt1 = ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Int32, true))); + + let columns = vec![ + Box::new(Int64Array::from_slice([27, 47])) as Box, + Box::new(Int64Array::from([Some(27), None])), + Box::new(Utf8Array::::from_slice(["foo", "bar"])), + Box::new(Utf8Array::::from([Some("foo"), None])), + Box::new(Int32Array::from_slice([1, 1])), + Box::new(Int32Array::from([Some(1), None])), + Box::new(Int32Array::from_slice([1, 2]).to(ArrowDataType::Date32)), + Box::new(Int32Array::from([Some(1), None]).to(ArrowDataType::Date32)), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + Box::new(PrimitiveArray::::from_slice([1.0, 2.0])), + Box::new(PrimitiveArray::::from([Some(1.0), None])), + Box::new(PrimitiveArray::::from_slice([1.0, 2.0])), + Box::new(PrimitiveArray::::from([Some(1.0), None])), + Box::new(BooleanArray::from_slice([true, false])), + Box::new(BooleanArray::from([Some(true), None])), + Box::new(ListArray::::new( + list_dt, + vec![0, 2, 5].try_into().unwrap(), + Box::new(PrimitiveArray::::from([ + None, + Some(1), + None, + Some(3), + Some(4), + ])), + None, + )), + Box::new(ListArray::::new( + list_dt1, + vec![0, 2, 2].try_into().unwrap(), + Box::new(PrimitiveArray::::from([None, Some(1)])), + Some([true, false].into()), + )), + ]; + + Chunk::new(columns) +} + +pub(super) fn serialize_to_block>( + columns: &Chunk, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult { + let record = write::to_record(schema, "".to_string())?; + + let mut serializers = columns + .arrays() + .iter() + .map(|x| x.as_ref()) + .zip(record.fields.iter()) + .map(|(array, field)| write::new_serializer(array, &field.schema)) + .collect::>(); + let mut block = Block::new(columns.len(), vec![]); + + write::serialize(&mut serializers, &mut block); + + let mut compressed_block = CompressedBlock::default(); + + compress(&mut block, &mut compressed_block, compression)?; + + Ok(compressed_block) +} + +fn write_avro>( + columns: &Chunk, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult> { + let compressed_block = serialize_to_block(columns, schema, compression)?; + + let avro_fields = write::to_record(schema, "".to_string())?; + let mut file = vec![]; + + write_metadata(&mut file, avro_fields, compression)?; + + write_block(&mut file, &compressed_block)?; + + Ok(file) +} + +fn roundtrip(compression: Option) -> PolarsResult<()> { + let expected = data(); + let expected_schema = schema(); + + let data = write_avro(&expected, &expected_schema, compression)?; + + let (result, read_schema) = read_avro(&data, None)?; + + assert_eq!(expected_schema, read_schema); + for (c1, c2) in result.columns().iter().zip(expected.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + Ok(()) +} + +#[test] +fn no_compression() -> PolarsResult<()> { + roundtrip(None) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn snappy() -> PolarsResult<()> { + roundtrip(Some(Compression::Snappy)) +} + +#[cfg(feature = "io_avro_compression")] +#[test] +fn deflate() -> PolarsResult<()> { + roundtrip(Some(Compression::Deflate)) +} + +fn large_format_schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new("large_utf8", ArrowDataType::LargeUtf8, false), + Field::new("large_utf8_nullable", ArrowDataType::LargeUtf8, true), + Field::new("large_binary", ArrowDataType::LargeBinary, false), + Field::new("large_binary_nullable", ArrowDataType::LargeBinary, true), + ]) +} + +fn large_format_data() -> Chunk> { + let columns = vec![ + Box::new(Utf8Array::::from_slice(["a", "b"])) as Box, + Box::new(Utf8Array::::from([Some("a"), None])), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + ]; + Chunk::new(columns) +} + +fn large_format_expected_schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new("large_utf8", ArrowDataType::Utf8, false), + Field::new("large_utf8_nullable", ArrowDataType::Utf8, true), + Field::new("large_binary", ArrowDataType::Binary, false), + Field::new("large_binary_nullable", ArrowDataType::Binary, true), + ]) +} + +fn large_format_expected_data() -> Chunk> { + let columns = vec![ + Box::new(Utf8Array::::from_slice(["a", "b"])) as Box, + Box::new(Utf8Array::::from([Some("a"), None])), + Box::new(BinaryArray::::from_slice([b"foo", b"bar"])), + Box::new(BinaryArray::::from([Some(b"foo"), None])), + ]; + Chunk::new(columns) +} + +#[test] +fn check_large_format() -> PolarsResult<()> { + let write_schema = large_format_schema(); + let write_data = large_format_data(); + + let data = write_avro(&write_data, &write_schema, None)?; + let (result, read_schame) = read_avro(&data, None)?; + + let expected_schema = large_format_expected_schema(); + assert_eq!(read_schame, expected_schema); + + let expected_data = large_format_expected_data(); + for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + + Ok(()) +} + +fn struct_schema() -> ArrowSchema { + ArrowSchema::from(vec![ + Field::new( + "struct", + ArrowDataType::Struct(vec![ + Field::new("item1", ArrowDataType::Int32, false), + Field::new("item2", ArrowDataType::Int32, true), + ]), + false, + ), + Field::new( + "struct nullable", + ArrowDataType::Struct(vec![ + Field::new("item1", ArrowDataType::Int32, false), + Field::new("item2", ArrowDataType::Int32, true), + ]), + true, + ), + ]) +} + +fn struct_data() -> Chunk> { + let struct_dt = ArrowDataType::Struct(vec![ + Field::new("item1", ArrowDataType::Int32, false), + Field::new("item2", ArrowDataType::Int32, true), + ]); + + Chunk::new(vec![ + Box::new(StructArray::new( + struct_dt.clone(), + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + None, + )), + Box::new(StructArray::new( + struct_dt, + vec![ + Box::new(PrimitiveArray::::from_slice([1, 2])), + Box::new(PrimitiveArray::::from([None, Some(1)])), + ], + Some([true, false].into()), + )), + ]) +} + +fn avro_record() -> Record { + Record { + name: "".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "struct".to_string(), + doc: None, + schema: AvroSchema::Record(Record { + name: "r1".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "item1".to_string(), + doc: None, + schema: AvroSchema::Int(None), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "item2".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Int(None), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + }), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "struct nullable".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Record(Record { + name: "r2".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + AvroField { + name: "item1".to_string(), + doc: None, + schema: AvroSchema::Int(None), + default: None, + order: None, + aliases: vec![], + }, + AvroField { + name: "item2".to_string(), + doc: None, + schema: AvroSchema::Union(vec![ + AvroSchema::Null, + AvroSchema::Int(None), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + }), + ]), + default: None, + order: None, + aliases: vec![], + }, + ], + } +} + +#[test] +fn avro_record_schema() -> PolarsResult<()> { + let arrow_schema = struct_schema(); + let record = write::to_record(&arrow_schema, "".to_string())?; + assert_eq!(record, avro_record()); + Ok(()) +} + +#[test] +fn struct_() -> PolarsResult<()> { + let write_schema = struct_schema(); + let write_data = struct_data(); + + let data = write_avro(&write_data, &write_schema, None)?; + let (result, read_schema) = read_avro(&data, None)?; + + let expected_schema = struct_schema(); + assert_eq!(read_schema, expected_schema); + + let expected_data = struct_data(); + for (c1, c2) in result.columns().iter().zip(expected_data.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + + Ok(()) +} diff --git a/crates/polars/tests/it/io/avro/write_async.rs b/crates/polars/tests/it/io/avro/write_async.rs new file mode 100644 index 0000000000000..7c04873af64cb --- /dev/null +++ b/crates/polars/tests/it/io/avro/write_async.rs @@ -0,0 +1,48 @@ +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::io::avro::write; +use avro_schema::file::Compression; +use avro_schema::write_async::{write_block, write_metadata}; +use polars_error::PolarsResult; + +use super::read::read_avro; +use super::write::{data, schema, serialize_to_block}; + +async fn write_avro>( + columns: &Chunk, + schema: &ArrowSchema, + compression: Option, +) -> PolarsResult> { + // usually done on a different thread pool + let compressed_block = serialize_to_block(columns, schema, compression)?; + + let record = write::to_record(schema, "".to_string())?; + let mut file = vec![]; + + write_metadata(&mut file, record, compression).await?; + + write_block(&mut file, &compressed_block).await?; + + Ok(file) +} + +async fn roundtrip(compression: Option) -> PolarsResult<()> { + let expected = data(); + let expected_schema = schema(); + + let data = write_avro(&expected, &expected_schema, compression).await?; + + let (result, read_schema) = read_avro(&data, None)?; + + assert_eq!(expected_schema, read_schema); + for (c1, c2) in result.columns().iter().zip(expected.columns().iter()) { + assert_eq!(c1.as_ref(), c2.as_ref()); + } + Ok(()) +} + +#[tokio::test] +async fn no_compression() -> PolarsResult<()> { + roundtrip(None).await +} diff --git a/crates/polars/tests/it/io/csv.rs b/crates/polars/tests/it/io/csv.rs index c855ad45f4c00..c180110463599 100644 --- a/crates/polars/tests/it/io/csv.rs +++ b/crates/polars/tests/it/io/csv.rs @@ -256,7 +256,7 @@ fn test_newline_in_custom_quote_char() { #[test] fn test_escape_2() { - // this is is harder than it looks. + // this is harder than it looks. // Fields: // * hello // * "," diff --git a/crates/polars/tests/it/io/ipc.rs b/crates/polars/tests/it/io/ipc.rs new file mode 100644 index 0000000000000..caf36c43f4fa1 --- /dev/null +++ b/crates/polars/tests/it/io/ipc.rs @@ -0,0 +1,23 @@ +use std::io::{Seek, SeekFrom}; + +use polars::prelude::*; + +#[test] +fn test_ipc_compression_variadic_buffers() { + let mut df = df![ + "foo" => std::iter::repeat("Home delivery vat 24 %").take(3).collect::>() + ] + .unwrap(); + + let mut file = std::io::Cursor::new(vec![]); + IpcWriter::new(&mut file) + .with_compression(Some(IpcCompression::LZ4)) + .with_pl_flavor(true) + .finish(&mut df) + .unwrap(); + + file.seek(SeekFrom::Start(0)).unwrap(); + let out = IpcReader::new(file).finish().unwrap(); + + assert_eq!(out.shape(), (3, 1)); +} diff --git a/crates/polars/tests/it/io/ipc_stream.rs b/crates/polars/tests/it/io/ipc_stream.rs index 2010db6b543c0..18d67990cb536 100644 --- a/crates/polars/tests/it/io/ipc_stream.rs +++ b/crates/polars/tests/it/io/ipc_stream.rs @@ -2,95 +2,112 @@ mod test { use std::io::Cursor; - use polars_core::df; use polars_core::prelude::*; + use polars_core::{assert_df_eq, df}; use polars_io::ipc::*; use polars_io::{SerReader, SerWriter}; use crate::io::create_df; - #[test] - fn write_and_read_ipc_stream() { - // Vec : Write + Read - // Cursor>: Seek + fn create_ipc_stream(mut df: DataFrame) -> Cursor> { let mut buf: Cursor> = Cursor::new(Vec::new()); - let mut df = create_df(); IpcStreamWriter::new(&mut buf) .finish(&mut df) - .expect("ipc writer"); + .expect("failed to write ICP stream"); buf.set_position(0); - let df_read = IpcStreamReader::new(buf).finish().unwrap(); - assert!(df.equals(&df_read)); + buf + } + + #[test] + fn write_and_read_ipc_stream() { + let df = create_df(); + + let reader = create_ipc_stream(df); + + let actual = IpcStreamReader::new(reader).finish().unwrap(); + + let expected = create_df(); + assert_df_eq!(actual, expected); } #[test] fn test_read_ipc_stream_with_projection() { - let mut buf: Cursor> = Cursor::new(Vec::new()); - let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + let df = df!( + "a" => [1], + "b" => [2], + "c" => [3], + ) + .unwrap(); - IpcStreamWriter::new(&mut buf) - .finish(&mut df) - .expect("ipc writer"); - buf.set_position(0); + let reader = create_ipc_stream(df); - let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); - let df_read = IpcStreamReader::new(buf) + let actual = IpcStreamReader::new(reader) .with_projection(Some(vec![1, 2])) .finish() .unwrap(); - assert_eq!(df_read.shape(), (3, 2)); - df_read.equals(&expected); + + let expected = df!( + "b" => [2], + "c" => [3], + ) + .unwrap(); + assert_df_eq!(actual, expected); } #[test] fn test_read_ipc_stream_with_columns() { - let mut buf: Cursor> = Cursor::new(Vec::new()); - let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); + let df = df!( + "a" => [1], + "b" => [2], + "c" => [3], + ) + .unwrap(); - IpcStreamWriter::new(&mut buf) - .finish(&mut df) - .expect("ipc writer"); - buf.set_position(0); + let reader = create_ipc_stream(df); - let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap(); - let df_read = IpcStreamReader::new(buf) + let actual = IpcStreamReader::new(reader) .with_columns(Some(vec!["c".to_string(), "b".to_string()])) .finish() .unwrap(); - df_read.equals(&expected); - let mut buf: Cursor> = Cursor::new(Vec::new()); - let mut df = df![ - "a" => ["x", "y", "z"], - "b" => [123, 456, 789], - "c" => [4.5, 10.0, 10.0], - "d" => ["misc", "other", "value"], - ] + let expected = df!( + "c" => [3], + "b" => [2], + ) .unwrap(); - IpcStreamWriter::new(&mut buf) - .finish(&mut df) - .expect("ipc writer"); - buf.set_position(0); - let expected = df![ - "a" => ["x", "y", "z"], - "c" => [4.5, 10.0, 10.0], - "d" => ["misc", "other", "value"], - "b" => [123, 456, 789], + assert_df_eq!(actual, expected); + } + + #[test] + fn test_read_ipc_stream_with_columns_reorder() { + let df = df![ + "a" => [1], + "b" => [2], + "c" => [3], ] .unwrap(); - let df_read = IpcStreamReader::new(buf) + + let reader = create_ipc_stream(df); + + let actual = IpcStreamReader::new(reader) .with_columns(Some(vec![ - "a".to_string(), - "c".to_string(), - "d".to_string(), "b".to_string(), + "c".to_string(), + "a".to_string(), ])) .finish() .unwrap(); - df_read.equals(&expected); + + let expected = df![ + "b" => [2], + "c" => [3], + "a" => [1], + ] + .unwrap(); + assert_df_eq!(actual, expected); } #[test] @@ -101,38 +118,39 @@ mod test { } #[test] - fn test_write_with_compression() { - let mut df = create_df(); - - let compressions = vec![None, Some(IpcCompression::LZ4), Some(IpcCompression::ZSTD)]; - - for compression in compressions.into_iter() { - let mut buf: Cursor> = Cursor::new(Vec::new()); - IpcStreamWriter::new(&mut buf) - .with_compression(compression) - .finish(&mut df) - .expect("ipc writer"); - buf.set_position(0); - - let df_read = IpcStreamReader::new(buf) - .finish() - .unwrap_or_else(|_| panic!("IPC reader: {:?}", compression)); - assert!(df.equals(&df_read)); - } + fn test_write_with_lz4_compression() { + test_write_with_compression(IpcCompression::LZ4); + } + + #[test] + fn test_write_with_zstd_compression() { + test_write_with_compression(IpcCompression::ZSTD); + } + + fn test_write_with_compression(compression: IpcCompression) { + let reader = { + let mut writer: Cursor> = Cursor::new(Vec::new()); + IpcStreamWriter::new(&mut writer) + .with_compression(Some(compression)) + .finish(&mut create_df()) + .unwrap(); + writer.set_position(0); + writer + }; + + let actual = IpcStreamReader::new(reader).finish().unwrap(); + assert_df_eq!(actual, create_df()); } #[test] fn write_and_read_ipc_stream_empty_series() { - let mut buf: Cursor> = Cursor::new(Vec::new()); - let chunked_array = Float64Chunked::new("empty", &[0_f64; 0]); - let mut df = DataFrame::new(vec![chunked_array.into_series()]).unwrap(); - IpcStreamWriter::new(&mut buf) - .finish(&mut df) - .expect("ipc writer"); + fn df() -> DataFrame { + DataFrame::new(vec![Float64Chunked::new("empty", &[0_f64; 0]).into_series()]).unwrap() + } - buf.set_position(0); + let reader = create_ipc_stream(df()); - let df_read = IpcStreamReader::new(buf).finish().unwrap(); - assert!(df.equals(&df_read)); + let actual = IpcStreamReader::new(reader).finish().unwrap(); + assert_df_eq!(df(), actual); } } diff --git a/crates/polars/tests/it/io/mod.rs b/crates/polars/tests/it/io/mod.rs index 7e384e214d4c4..4835171721c99 100644 --- a/crates/polars/tests/it/io/mod.rs +++ b/crates/polars/tests/it/io/mod.rs @@ -6,6 +6,11 @@ mod json; #[cfg(feature = "parquet")] mod parquet; +#[cfg(feature = "avro")] +mod avro; + +#[cfg(feature = "ipc")] +mod ipc; #[cfg(feature = "ipc_streaming")] mod ipc_stream; diff --git a/crates/polars/tests/it/io/parquet.rs b/crates/polars/tests/it/io/parquet.rs deleted file mode 100644 index ad5349ced1519..0000000000000 --- a/crates/polars/tests/it/io/parquet.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::io::Cursor; - -use polars::prelude::*; - -#[test] -fn test_vstack_empty_3220() -> PolarsResult<()> { - let df1 = df! { - "a" => ["1", "2"], - "b" => [1, 2] - }?; - let empty_df = df1.head(Some(0)); - let mut stacked = df1.clone(); - stacked.vstack_mut(&empty_df)?; - stacked.vstack_mut(&df1)?; - let mut buf = Cursor::new(Vec::new()); - ParquetWriter::new(&mut buf).finish(&mut stacked)?; - let read_df = ParquetReader::new(buf).finish()?; - assert!(stacked.equals(&read_df)); - Ok(()) -} diff --git a/crates/polars/tests/it/io/parquet/arrow/integration.rs b/crates/polars/tests/it/io/parquet/arrow/integration.rs new file mode 100644 index 0000000000000..7f84c433b0d55 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/integration.rs @@ -0,0 +1,41 @@ +use arrow2::error::Result; + +use super::{integration_read, integration_write}; +use crate::io::ipc::read_gzip_json; + +fn test_file(version: &str, file_name: &str) -> Result<()> { + let (schema, _, batches) = read_gzip_json(version, file_name)?; + + // empty batches are not written/read from parquet and can be ignored + let batches = batches + .into_iter() + .filter(|x| !x.is_empty()) + .collect::>(); + + let data = integration_write(&schema, &batches)?; + + let (read_schema, read_batches) = integration_read(&data, None)?; + + assert_eq!(schema, read_schema); + assert_eq!(batches, read_batches); + + Ok(()) +} + +#[test] +fn roundtrip_100_primitive() -> Result<()> { + test_file("1.0.0-littleendian", "generated_primitive")?; + test_file("1.0.0-bigendian", "generated_primitive") +} + +#[test] +fn roundtrip_100_dict() -> Result<()> { + test_file("1.0.0-littleendian", "generated_dictionary")?; + test_file("1.0.0-bigendian", "generated_dictionary") +} + +#[test] +fn roundtrip_100_extension() -> Result<()> { + test_file("1.0.0-littleendian", "generated_extension")?; + test_file("1.0.0-bigendian", "generated_extension") +} diff --git a/crates/polars/tests/it/io/parquet/arrow/mod.rs b/crates/polars/tests/it/io/parquet/arrow/mod.rs new file mode 100644 index 0000000000000..d7832a6567ead --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/mod.rs @@ -0,0 +1,1661 @@ +use std::io::{Cursor, Read, Seek}; + +use arrow::array::*; +use arrow::bitmap::Bitmap; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use arrow::legacy::prelude::LargeListArray; +use arrow::types::{i256, NativeType}; +use ethnum::AsI256; +use polars_error::PolarsResult; +use polars_parquet::read as p_read; +use polars_parquet::read::statistics::*; +use polars_parquet::write::*; + +#[cfg(feature = "io_json_integration")] +mod integration; +mod read; +mod read_indexes; +mod write; + +#[cfg(feature = "io_parquet_sample_test")] +mod sample_tests; + +type ArrayStats = (Box, Statistics); + +fn new_struct( + arrays: Vec>, + names: Vec, + validity: Option, +) -> StructArray { + let fields = names + .into_iter() + .zip(arrays.iter()) + .map(|(n, a)| Field::new(n, a.data_type().clone(), true)) + .collect(); + StructArray::new(ArrowDataType::Struct(fields), arrays, validity) +} + +pub fn read_column(mut reader: R, column: &str) -> PolarsResult { + let metadata = p_read::read_metadata(&mut reader)?; + let schema = p_read::infer_schema(&metadata)?; + + let row_group = &metadata.row_groups[0]; + + // verify that we can read indexes + if p_read::indexes::has_indexes(row_group) { + let _indexes = p_read::indexes::read_filtered_pages( + &mut reader, + row_group, + &schema.fields, + |_, _| vec![], + )?; + } + + let schema = schema.filter(|_, f| f.name == column); + + let field = &schema.fields[0]; + + let statistics = deserialize(field, row_group)?; + + let mut reader = p_read::FileReader::new(reader, metadata.row_groups, schema, None, None, None); + + let array = reader.next().unwrap()?.into_arrays().pop().unwrap(); + + Ok((array, statistics)) +} + +pub fn pyarrow_nested_edge(column: &str) -> Box { + match column { + "simple" => { + // [[0, 1]] + let data = [Some(vec![Some(0), Some(1)])]; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "null" => { + // [None] + let data = [None::>>]; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "empty" => { + // [None] + let data: [Option>>; 0] = []; + let mut a = MutableListArray::>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "struct_list_nullable" => { + // [ + // {"f1": ["a", "b", None, "c"]} + // ] + let a = ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + ArrowDataType::Utf8View, + true, + ))), + vec![0, 4].try_into().unwrap(), + Utf8ViewArray::from_slice([Some("a"), Some("b"), None, Some("c")]).boxed(), + None, + ); + StructArray::new( + ArrowDataType::Struct(vec![Field::new("f1", a.data_type().clone(), true)]), + vec![a.boxed()], + None, + ) + .boxed() + }, + "list_struct_list_nullable" => { + let values = pyarrow_nested_edge("struct_list_nullable"); + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + values.data_type().clone(), + true, + ))), + vec![0, 1].try_into().unwrap(), + values, + None, + ) + .boxed() + }, + _ => todo!(), + } +} + +pub fn pyarrow_nested_nullable(column: &str) -> Box { + let i64_values = &[ + Some(0), + Some(1), + Some(2), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]; + let offsets = vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(); + + let values = match column { + "list_int64" => { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + PrimitiveArray::::from(i64_values).boxed() + }, + "list_int64_required" | "list_int64_optional_required" | "list_int64_required_required" => { + // [[0, 1], None, [2, 0, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + PrimitiveArray::::from(&[ + Some(0), + Some(1), + Some(2), + Some(0), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]) + .boxed() + }, + "list_int16" => PrimitiveArray::::from(&[ + Some(0), + Some(1), + Some(2), + None, + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + Some(10), + ]) + .boxed(), + "list_bool" => BooleanArray::from(&[ + Some(false), + Some(true), + Some(true), + None, + Some(false), + Some(true), + Some(false), + Some(true), + Some(false), + Some(false), + Some(false), + Some(true), + ]) + .boxed(), + /* + string = [ + ["Hello", "bbb"], + None, + ["aa", None, ""], + ["bbb", "aa", "ccc"], + [], + ["abc", "bbb", "bbb"], + None, + [""], + ] + */ + "list_utf8" => Utf8ViewArray::from_slice([ + Some("Hello".to_string()), + Some("bbb".to_string()), + Some("aa".to_string()), + None, + Some("".to_string()), + Some("bbb".to_string()), + Some("aa".to_string()), + Some("ccc".to_string()), + Some("abc".to_string()), + Some("bbb".to_string()), + Some("bbb".to_string()), + Some("".to_string()), + ]) + .boxed(), + "list_large_binary" => Box::new(BinaryArray::::from([ + Some(b"Hello".to_vec()), + Some(b"bbb".to_vec()), + Some(b"aa".to_vec()), + None, + Some(b"".to_vec()), + Some(b"bbb".to_vec()), + Some(b"aa".to_vec()), + Some(b"ccc".to_vec()), + Some(b"abc".to_vec()), + Some(b"bbb".to_vec()), + Some(b"bbb".to_vec()), + Some(b"".to_vec()), + ])), + "list_decimal" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| x as i128)) + .collect::>(); + Box::new(PrimitiveArray::::from(values).to(ArrowDataType::Decimal(9, 0))) + }, + "list_decimal256" => { + let values = i64_values + .iter() + .map(|x| x.map(|x| i256(x.as_i256()))) + .collect::>(); + let array = PrimitiveArray::::from(values).to(ArrowDataType::Decimal256(9, 0)); + Box::new(array) + }, + "list_nested_i64" + | "list_nested_inner_required_i64" + | "list_nested_inner_required_required_i64" => { + Box::new(NullArray::new(ArrowDataType::Null, 1)) + }, + "struct_list_nullable" => pyarrow_nested_nullable("list_utf8"), + "list_struct_nullable" => { + let array = Utf8ViewArray::from_slice([ + Some("a"), + Some("b"), + // + Some("b"), + None, + Some("b"), + // + None, + None, + None, + // + Some("d"), + Some("d"), + Some("d"), + // + Some("e"), + ]) + .boxed(); + new_struct( + vec![array], + vec!["a".to_string()], + Some( + [ + true, true, // + true, false, true, // + true, true, true, // + true, true, true, // + true, + ] + .into(), + ), + ) + .boxed() + }, + "list_struct_list_nullable" => { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let array = Utf8ViewArray::from_slice([ + Some("a"), + Some("b"), + // + Some("b"), + Some("b"), + // + Some("d"), + None, + Some("c"), + Some("d"), + ]) + .boxed(); + + let array = ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + array.data_type().clone(), + true, + ))), + vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] + .try_into() + .unwrap(), + array, + Some( + [ + true, true, true, false, true, false, false, false, true, true, true, true, + ] + .into(), + ), + ) + .boxed(); + + new_struct( + vec![array], + vec!["a".to_string()], + Some( + [ + true, true, // + true, false, true, // + true, true, true, // + true, true, true, // + true, + ] + .into(), + ), + ) + .boxed() + }, + other => unreachable!("{}", other), + }; + + match column { + "list_int64_required_required" => { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data_type = + ArrowDataType::LargeList(Box::new(Field::new("item", ArrowDataType::Int64, false))); + ListArray::::new(data_type, offsets, values, None).boxed() + }, + "list_int64_optional_required" => { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data_type = + ArrowDataType::LargeList(Box::new(Field::new("item", ArrowDataType::Int64, true))); + ListArray::::new(data_type, offsets, values, None).boxed() + }, + "list_nested_i64" => { + // [[0, 1]], None, [[2, None], [3]], [[4, 5], [6]], [], [[7], None, [9]], [[], [None], None], [[10]] + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), None]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![Some(vec![Some(7)]), None, Some(vec![Some(9)])]), + Some(vec![Some(vec![]), Some(vec![None]), None]), + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "list_nested_inner_required_i64" => { + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), Some(3)]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![Some(vec![Some(7)]), None, Some(vec![Some(9)])]), + None, + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "list_nested_inner_required_required_i64" => { + let data = [ + Some(vec![Some(vec![Some(0), Some(1)])]), + None, + Some(vec![Some(vec![Some(2), Some(3)]), Some(vec![Some(3)])]), + Some(vec![Some(vec![Some(4), Some(5)]), Some(vec![Some(6)])]), + Some(vec![]), + Some(vec![ + Some(vec![Some(7)]), + Some(vec![Some(8)]), + Some(vec![Some(9)]), + ]), + None, + Some(vec![Some(vec![Some(10)])]), + ]; + let mut a = + MutableListArray::>>::new(); + a.try_extend(data).unwrap(); + let array: ListArray = a.into(); + Box::new(array) + }, + "struct_list_nullable" => new_struct(vec![values], vec!["a".to_string()], None).boxed(), + _ => { + let field = match column { + "list_int64" => Field::new("item", ArrowDataType::Int64, true), + "list_int64_required" => Field::new("item", ArrowDataType::Int64, false), + "list_int16" => Field::new("item", ArrowDataType::Int16, true), + "list_bool" => Field::new("item", ArrowDataType::Boolean, true), + "list_utf8" => Field::new("item", ArrowDataType::Utf8View, true), + "list_large_binary" => Field::new("item", ArrowDataType::LargeBinary, true), + "list_decimal" => Field::new("item", ArrowDataType::Decimal(9, 0), true), + "list_decimal256" => Field::new("item", ArrowDataType::Decimal256(9, 0), true), + "list_struct_nullable" => Field::new("item", values.data_type().clone(), true), + "list_struct_list_nullable" => Field::new("item", values.data_type().clone(), true), + other => unreachable!("{}", other), + }; + + let validity = Some(Bitmap::from([ + true, false, true, true, true, true, false, true, + ])); + // [0, 2, 2, 5, 8, 8, 11, 11, 12] + // [[a1, a2], None, [a3, a4, a5], [a6, a7, a8], [], [a9, a10, a11], None, [a12]] + let data_type = ArrowDataType::LargeList(Box::new(field)); + ListArray::::new(data_type, offsets, values, validity).boxed() + }, + } +} + +pub fn pyarrow_nullable(column: &str) -> Box { + let i64_values = &[ + Some(-256), + Some(-1), + None, + Some(3), + None, + Some(5), + Some(6), + Some(7), + None, + Some(9), + ]; + let u32_values = &[ + Some(0), + Some(1), + None, + Some(3), + None, + Some(5), + Some(6), + Some(7), + None, + Some(9), + ]; + + match column { + "int64" => Box::new(PrimitiveArray::::from(i64_values)), + "float64" => Box::new(PrimitiveArray::::from(&[ + Some(0.0), + Some(1.0), + None, + Some(3.0), + None, + Some(5.0), + Some(6.0), + Some(7.0), + None, + Some(9.0), + ])), + "string" => Box::new(Utf8ViewArray::from_slice([ + Some("Hello".to_string()), + None, + Some("aa".to_string()), + Some("".to_string()), + None, + Some("abc".to_string()), + None, + None, + Some("def".to_string()), + Some("aaa".to_string()), + ])), + "bool" => Box::new(BooleanArray::from([ + Some(true), + None, + Some(false), + Some(false), + None, + Some(true), + None, + None, + Some(true), + Some(true), + ])), + "timestamp_ms" => Box::new( + PrimitiveArray::::from_iter(u32_values.iter().map(|x| x.map(|x| x as i64))) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + "uint32" => Box::new(PrimitiveArray::::from(u32_values)), + "int32_dict" => { + let keys = PrimitiveArray::::from([Some(0), Some(1), None, Some(1)]); + let values = Box::new(PrimitiveArray::::from_slice([10, 200])); + Box::new(DictionaryArray::try_from_keys(keys, values).unwrap()) + }, + "timestamp_us" => Box::new( + PrimitiveArray::::from(i64_values) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + "timestamp_s" => Box::new( + PrimitiveArray::::from(i64_values) + .to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + "timestamp_s_utc" => Box::new(PrimitiveArray::::from(i64_values).to( + ArrowDataType::Timestamp(TimeUnit::Second, Some("UTC".to_string())), + )), + _ => unreachable!(), + } +} + +pub fn pyarrow_nullable_statistics(column: &str) -> Statistics { + match column { + "int64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(Int64Array::from_slice([-256])), + max_value: Box::new(Int64Array::from_slice([9])), + }, + "float64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(Float64Array::from_slice([0.0])), + max_value: Box::new(Float64Array::from_slice([9.0])), + }, + "string" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(4)]).boxed(), + min_value: Box::new(Utf8ViewArray::from_slice([Some("")])), + max_value: Box::new(Utf8ViewArray::from_slice([Some("def")])), + }, + "bool" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(4)]).boxed(), + min_value: Box::new(BooleanArray::from_slice([false])), + max_value: Box::new(BooleanArray::from_slice([true])), + }, + "timestamp_ms" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int64Array::from_slice([0]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + max_value: Box::new( + Int64Array::from_slice([9]) + .to(ArrowDataType::Timestamp(TimeUnit::Millisecond, None)), + ), + }, + "uint32" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(UInt32Array::from_slice([0])), + max_value: Box::new(UInt32Array::from_slice([9])), + }, + "int32_dict" => { + let new_dict = |array: Box| -> Box { + Box::new(DictionaryArray::try_from_keys(vec![Some(0)].into(), array).unwrap()) + }; + + Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(1)]).boxed(), + min_value: new_dict(Box::new(Int32Array::from_slice([10]))), + max_value: new_dict(Box::new(Int32Array::from_slice([200]))), + } + }, + "timestamp_us" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int64Array::from_slice([-256]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + max_value: Box::new( + Int64Array::from_slice([9]) + .to(ArrowDataType::Timestamp(TimeUnit::Microsecond, None)), + ), + }, + "timestamp_s" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new( + Int64Array::from_slice([-256]).to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + max_value: Box::new( + Int64Array::from_slice([9]).to(ArrowDataType::Timestamp(TimeUnit::Second, None)), + ), + }, + "timestamp_s_utc" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(3)]).boxed(), + min_value: Box::new(Int64Array::from_slice([-256]).to(ArrowDataType::Timestamp( + TimeUnit::Second, + Some("UTC".to_string()), + ))), + max_value: Box::new(Int64Array::from_slice([9]).to(ArrowDataType::Timestamp( + TimeUnit::Second, + Some("UTC".to_string()), + ))), + }, + _ => unreachable!(), + } +} + +// these values match the values in `integration` +pub fn pyarrow_required(column: &str) -> Box { + let i64_values = &[ + Some(-256), + Some(-1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + Some(7), + Some(8), + Some(9), + ]; + + match column { + "int64" => Box::new(PrimitiveArray::::from(i64_values)), + "bool" => Box::new(BooleanArray::from_slice([ + true, true, false, false, false, true, true, true, true, true, + ])), + "string" => Box::new(Utf8ViewArray::from_slice([ + Some("Hello"), + Some("bbb"), + Some("aa"), + Some(""), + Some("bbb"), + Some("abc"), + Some("bbb"), + Some("bbb"), + Some("def"), + Some("aaa"), + ])), + _ => unreachable!(), + } +} + +pub fn pyarrow_required_statistics(column: &str) -> Statistics { + let mut s = pyarrow_nullable_statistics(column); + s.null_count = UInt64Array::from([Some(0)]).boxed(); + s +} + +pub fn pyarrow_nested_nullable_statistics(column: &str) -> Statistics { + let new_list = |array: Box, nullable: bool| { + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + array.data_type().clone(), + nullable, + ))), + vec![0, array.len() as i64].try_into().unwrap(), + array, + None, + ) + }; + + match column { + "list_int16" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Int16Array::from_slice([0])), true).boxed(), + max_value: new_list(Box::new(Int16Array::from_slice([10])), true).boxed(), + }, + "list_bool" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(BooleanArray::from_slice([false])), true).boxed(), + max_value: new_list(Box::new(BooleanArray::from_slice([true])), true).boxed(), + }, + "list_utf8" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Utf8ViewArray::from_slice([Some("")])), true).boxed(), + max_value: new_list(Box::new(Utf8ViewArray::from_slice([Some("ccc")])), true).boxed(), + }, + "list_large_binary" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(BinaryArray::::from_slice([b""])), true).boxed(), + max_value: new_list(Box::new(BinaryArray::::from_slice([b"ccc"])), true).boxed(), + }, + "list_decimal" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list( + Box::new(Int128Array::from_slice([0]).to(ArrowDataType::Decimal(9, 0))), + true, + ) + .boxed(), + max_value: new_list( + Box::new(Int128Array::from_slice([10]).to(ArrowDataType::Decimal(9, 0))), + true, + ) + .boxed(), + }, + "list_decimal256" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list( + Box::new( + Int256Array::from_slice([i256(0.as_i256())]) + .to(ArrowDataType::Decimal256(9, 0)), + ), + true, + ) + .boxed(), + max_value: new_list( + Box::new( + Int256Array::from_slice([i256(10.as_i256())]) + .to(ArrowDataType::Decimal256(9, 0)), + ), + true, + ) + .boxed(), + }, + "list_int64" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + max_value: new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + }, + "list_int64_required" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed(), + min_value: new_list(Box::new(Int64Array::from_slice([0])), false).boxed(), + max_value: new_list(Box::new(Int64Array::from_slice([10])), false).boxed(), + }, + "list_int64_required_required" | "list_int64_optional_required" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), false).boxed(), + null_count: new_list(UInt64Array::from([Some(0)]).boxed(), false).boxed(), + min_value: new_list(Box::new(Int64Array::from_slice([0])), false).boxed(), + max_value: new_list(Box::new(Int64Array::from_slice([10])), false).boxed(), + }, + "list_nested_i64" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed(), true).boxed(), + null_count: new_list(UInt64Array::from([Some(2)]).boxed(), true).boxed(), + min_value: new_list( + new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + true, + ) + .boxed(), + }, + "list_nested_inner_required_required_i64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(0)]).boxed(), + min_value: new_list( + new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + true, + ) + .boxed(), + }, + "list_nested_inner_required_i64" => Statistics { + distinct_count: UInt64Array::from([None]).boxed(), + null_count: UInt64Array::from([Some(0)]).boxed(), + min_value: new_list( + new_list(Box::new(Int64Array::from_slice([0])), true).boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_list(Box::new(Int64Array::from_slice([10])), true).boxed(), + true, + ) + .boxed(), + }, + "list_struct_nullable" => Statistics { + distinct_count: new_list( + new_struct( + vec![UInt64Array::from([None]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + null_count: new_list( + new_struct( + vec![UInt64Array::from([Some(4)]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + min_value: new_list( + new_struct( + vec![Utf8ViewArray::from_slice([Some("a")]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_struct( + vec![Utf8ViewArray::from_slice([Some("e")]).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + }, + "list_struct_list_nullable" => Statistics { + distinct_count: new_list( + new_struct( + vec![new_list(UInt64Array::from([None]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + null_count: new_list( + new_struct( + vec![new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + min_value: new_list( + new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("a")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + max_value: new_list( + new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("d")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + true, + ) + .boxed(), + }, + "struct_list_nullable" => Statistics { + distinct_count: new_struct( + vec![new_list(UInt64Array::from([None]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + null_count: new_struct( + vec![new_list(UInt64Array::from([Some(1)]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + min_value: new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + max_value: new_struct( + vec![new_list(Utf8ViewArray::from_slice([Some("ccc")]).boxed(), true).boxed()], + vec!["a".to_string()], + None, + ) + .boxed(), + }, + other => todo!("{}", other), + } +} + +pub fn pyarrow_nested_edge_statistics(column: &str) -> Statistics { + let new_list = |array: Box| { + ListArray::::new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + array.data_type().clone(), + true, + ))), + vec![0, array.len() as i64].try_into().unwrap(), + array, + None, + ) + }; + + let new_struct = |arrays: Vec>, names: Vec| { + let fields = names + .into_iter() + .zip(arrays.iter()) + .map(|(n, a)| Field::new(n, a.data_type().clone(), true)) + .collect(); + StructArray::new(ArrowDataType::Struct(fields), arrays, None) + }; + + let names = vec!["f1".to_string()]; + + match column { + "simple" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed()).boxed(), + null_count: new_list(UInt64Array::from([Some(0)]).boxed()).boxed(), + min_value: new_list(Box::new(Int64Array::from([Some(0)]))).boxed(), + max_value: new_list(Box::new(Int64Array::from([Some(1)]))).boxed(), + }, + "null" | "empty" => Statistics { + distinct_count: new_list(UInt64Array::from([None]).boxed()).boxed(), + null_count: new_list(UInt64Array::from([Some(0)]).boxed()).boxed(), + min_value: new_list(Box::new(Int64Array::from([None]))).boxed(), + max_value: new_list(Box::new(Int64Array::from([None]))).boxed(), + }, + "struct_list_nullable" => Statistics { + distinct_count: new_struct( + vec![new_list(Box::new(UInt64Array::from([None]))).boxed()], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![new_list(Box::new(UInt64Array::from([Some(1)]))).boxed()], + names.clone(), + ) + .boxed(), + min_value: Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("a")]))).boxed()], + names.clone(), + )), + max_value: Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("c")]))).boxed()], + names, + )), + }, + "list_struct_list_nullable" => Statistics { + distinct_count: new_list( + new_struct( + vec![new_list(Box::new(UInt64Array::from([None]))).boxed()], + names.clone(), + ) + .boxed(), + ) + .boxed(), + null_count: new_list( + new_struct( + vec![new_list(Box::new(UInt64Array::from([Some(1)]))).boxed()], + names.clone(), + ) + .boxed(), + ) + .boxed(), + min_value: new_list(Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("a")]))).boxed()], + names.clone(), + ))) + .boxed(), + max_value: new_list(Box::new(new_struct( + vec![new_list(Box::new(Utf8ViewArray::from_slice([Some("c")]))).boxed()], + names, + ))) + .boxed(), + }, + _ => unreachable!(), + } +} + +pub fn pyarrow_struct(column: &str) -> Box { + let boolean = [ + Some(true), + None, + Some(false), + Some(false), + None, + Some(true), + None, + None, + Some(true), + Some(true), + ]; + let boolean = BooleanArray::from(boolean).boxed(); + + let string = [ + Some("Hello"), + None, + Some("aa"), + Some(""), + None, + Some("abc"), + None, + None, + Some("def"), + Some("aaa"), + ]; + let string = Utf8ViewArray::from_slice(string).boxed(); + + let mask = [true, true, false, true, true, true, true, true, true, true]; + + let fields = vec![ + Field::new("f1", ArrowDataType::Utf8View, true), + Field::new("f2", ArrowDataType::Boolean, true), + ]; + match column { + "struct" => { + StructArray::new(ArrowDataType::Struct(fields), vec![string, boolean], None).boxed() + }, + "struct_nullable" => { + let values = vec![string, boolean]; + StructArray::new(ArrowDataType::Struct(fields), values, Some(mask.into())).boxed() + }, + "struct_struct" => { + let struct_ = pyarrow_struct("struct"); + Box::new(StructArray::new( + ArrowDataType::Struct(vec![ + Field::new("f1", ArrowDataType::Struct(fields), true), + Field::new("f2", ArrowDataType::Boolean, true), + ]), + vec![struct_, boolean], + None, + )) + }, + "struct_struct_nullable" => { + let struct_ = pyarrow_struct("struct"); + Box::new(StructArray::new( + ArrowDataType::Struct(vec![ + Field::new("f1", ArrowDataType::Struct(fields), true), + Field::new("f2", ArrowDataType::Boolean, true), + ]), + vec![struct_, boolean], + Some(mask.into()), + )) + }, + _ => todo!(), + } +} + +pub fn pyarrow_struct_statistics(column: &str) -> Statistics { + let new_struct = + |arrays: Vec>, names: Vec| new_struct(arrays, names, None); + + let names = vec!["f1".to_string(), "f2".to_string()]; + + match column { + "struct" | "struct_nullable" => Statistics { + distinct_count: new_struct( + vec![ + Box::new(UInt64Array::from([None])), + Box::new(UInt64Array::from([None])), + ], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![ + Box::new(UInt64Array::from([Some(4)])), + Box::new(UInt64Array::from([Some(4)])), + ], + names.clone(), + ) + .boxed(), + min_value: Box::new(new_struct( + vec![ + Box::new(Utf8ViewArray::from_slice([Some("")])), + Box::new(BooleanArray::from_slice([false])), + ], + names.clone(), + )), + max_value: Box::new(new_struct( + vec![ + Box::new(Utf8ViewArray::from_slice([Some("def")])), + Box::new(BooleanArray::from_slice([true])), + ], + names, + )), + }, + "struct_struct" => Statistics { + distinct_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([None])), + Box::new(UInt64Array::from([None])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([None]).boxed(), + ], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([Some(4)])), + Box::new(UInt64Array::from([Some(4)])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([Some(4)]).boxed(), + ], + names.clone(), + ) + .boxed(), + min_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("")]).boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + max_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("def")]).boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names, + ) + .boxed(), + }, + "struct_struct_nullable" => Statistics { + distinct_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([None])), + Box::new(UInt64Array::from([None])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([None]).boxed(), + ], + names.clone(), + ) + .boxed(), + null_count: new_struct( + vec![ + new_struct( + vec![ + Box::new(UInt64Array::from([Some(5)])), + Box::new(UInt64Array::from([Some(5)])), + ], + names.clone(), + ) + .boxed(), + UInt64Array::from([Some(5)]).boxed(), + ], + names.clone(), + ) + .boxed(), + min_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("")]).boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([false]).boxed(), + ], + names.clone(), + ) + .boxed(), + max_value: new_struct( + vec![ + new_struct( + vec![ + Utf8ViewArray::from_slice([Some("def")]).boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names.clone(), + ) + .boxed(), + BooleanArray::from_slice([true]).boxed(), + ], + names, + ) + .boxed(), + }, + _ => todo!(), + } +} + +fn integration_write( + schema: &ArrowSchema, + chunks: &[Chunk>], +) -> PolarsResult> { + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_pagesize_limit: None, + }; + + let encodings = schema + .fields + .iter() + .map(|f| { + transverse(&f.data_type, |x| { + if let ArrowDataType::Dictionary(..) = x { + Encoding::RleDictionary + } else { + Encoding::Plain + } + }) + }) + .collect(); + + let row_groups = + RowGroupIterator::try_new(chunks.iter().cloned().map(Ok), schema, options, encodings)?; + + let writer = Cursor::new(vec![]); + + let mut writer = FileWriter::try_new(writer, schema.clone(), options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + Ok(writer.into_inner().into_inner()) +} + +type IntegrationRead = (ArrowSchema, Vec>>); + +fn integration_read(data: &[u8], limit: Option) -> PolarsResult { + let mut reader = Cursor::new(data); + let metadata = p_read::read_metadata(&mut reader)?; + let schema = p_read::infer_schema(&metadata)?; + + for (field, row_group) in schema.fields.iter().zip(metadata.row_groups.iter()) { + let mut _statistics = deserialize(field, row_group)?; + } + + let reader = p_read::FileReader::new( + Cursor::new(data), + metadata.row_groups, + schema.clone(), + None, + limit, + None, + ); + + let batches = reader.collect::>>()?; + + Ok((schema, batches)) +} + +fn generic_data() -> PolarsResult<(ArrowSchema, Chunk>)> { + let array1 = PrimitiveArray::::from([Some(1), None, Some(2)]) + .to(ArrowDataType::Duration(TimeUnit::Second)); + let array2 = Utf8ViewArray::from_slice([Some("a"), None, Some("bb")]); + + let indices = PrimitiveArray::from_values((0..3u64).map(|x| x % 2)); + let values = PrimitiveArray::from_slice([1.0f32, 3.0]).boxed(); + let array3 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let array4 = BinaryViewArray::from_slice([Some(b"ab"), Some(b"aa"), Some(b"ac")]); + + let values = PrimitiveArray::from_slice([1i16, 3]).boxed(); + let array6 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1i64, 3]) + .to(ArrowDataType::Timestamp( + TimeUnit::Millisecond, + Some("UTC".to_string()), + )) + .boxed(); + let array7 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1.0f64, 3.0]).boxed(); + let array8 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u8, 3]).boxed(); + let array9 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u16, 3]).boxed(); + let array10 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u32, 3]).boxed(); + let array11 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); + + let values = PrimitiveArray::from_slice([1u64, 3]).boxed(); + let array12 = DictionaryArray::try_from_keys(indices, values).unwrap(); + + let array13 = PrimitiveArray::::from_slice([1, 2, 3]) + .to(ArrowDataType::Interval(IntervalUnit::YearMonth)); + + let schema = ArrowSchema::from(vec![ + Field::new("a1", array1.data_type().clone(), true), + Field::new("a2", array2.data_type().clone(), true), + Field::new("a3", array3.data_type().clone(), true), + Field::new("a4", array4.data_type().clone(), true), + Field::new("a6", array6.data_type().clone(), true), + Field::new("a7", array7.data_type().clone(), true), + Field::new("a8", array8.data_type().clone(), true), + Field::new("a9", array9.data_type().clone(), true), + Field::new("a10", array10.data_type().clone(), true), + Field::new("a11", array11.data_type().clone(), true), + Field::new("a12", array12.data_type().clone(), true), + Field::new("a13", array13.data_type().clone(), true), + ]); + let chunk = Chunk::try_new(vec![ + array1.boxed(), + array2.boxed(), + array3.boxed(), + array4.boxed(), + array6.boxed(), + array7.boxed(), + array8.boxed(), + array9.boxed(), + array10.boxed(), + array11.boxed(), + array12.boxed(), + array13.boxed(), + ])?; + + Ok((schema, chunk)) +} + +fn assert_roundtrip( + schema: ArrowSchema, + chunk: Chunk>, + limit: Option, +) -> PolarsResult<()> { + let r = integration_write(&schema, &[chunk.clone()])?; + + let (new_schema, new_chunks) = integration_read(&r, limit)?; + + let expected = if let Some(limit) = limit { + let expected = chunk + .into_arrays() + .into_iter() + .map(|x| x.sliced(0, limit)) + .collect::>(); + Chunk::new(expected) + } else { + chunk + }; + + assert_eq!(new_schema, schema); + assert_eq!(new_chunks, vec![expected]); + Ok(()) +} + +/// Tests that when arrow-specific types (Duration and LargeUtf8) are written to parquet, we can roundtrip its +/// logical types. +#[test] +fn arrow_type() -> PolarsResult<()> { + let (schema, chunk) = generic_data()?; + assert_roundtrip(schema, chunk, None) +} + +fn data>( + mut iter: I, + inner_is_nullable: bool, +) -> Box { + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + let data = vec![ + Some(vec![Some(iter.next().unwrap()), Some(iter.next().unwrap())]), + Some(vec![]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![]), + Some(vec![ + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + Some(iter.next().unwrap()), + ]), + Some(vec![]), + Some(vec![Some(iter.next().unwrap())]), + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item", + inner_is_nullable, + ); + array.try_extend(data).unwrap(); + array.into_box() +} + +fn assert_array_roundtrip( + is_nullable: bool, + array: Box, + limit: Option, +) -> PolarsResult<()> { + let schema = ArrowSchema::from(vec![Field::new( + "a1", + array.data_type().clone(), + is_nullable, + )]); + let chunk = Chunk::try_new(vec![array])?; + + assert_roundtrip(schema, chunk, limit) +} + +fn test_list_array_required_required(limit: Option) -> PolarsResult<()> { + assert_array_roundtrip(false, data(0..12i8, false), limit)?; + assert_array_roundtrip(false, data(0..12i16, false), limit)?; + assert_array_roundtrip(false, data(0..12i64, false), limit)?; + assert_array_roundtrip(false, data(0..12i64, false), limit)?; + assert_array_roundtrip(false, data(0..12u8, false), limit)?; + assert_array_roundtrip(false, data(0..12u16, false), limit)?; + assert_array_roundtrip(false, data(0..12u32, false), limit)?; + assert_array_roundtrip(false, data(0..12u64, false), limit)?; + assert_array_roundtrip(false, data((0..12).map(|x| (x as f32) * 1.0), false), limit)?; + assert_array_roundtrip( + false, + data((0..12).map(|x| (x as f64) * 1.0f64), false), + limit, + ) +} + +#[test] +fn list_array_required_required() -> PolarsResult<()> { + test_list_array_required_required(None) +} + +#[test] +fn list_array_optional_optional() -> PolarsResult<()> { + assert_array_roundtrip(true, data(0..12, true), None) +} + +#[test] +fn list_array_required_optional() -> PolarsResult<()> { + assert_array_roundtrip(true, data(0..12, false), None) +} + +#[test] +fn list_array_optional_required() -> PolarsResult<()> { + assert_array_roundtrip(false, data(0..12, true), None) +} + +#[test] +fn list_slice() -> PolarsResult<()> { + let data = vec![ + Some(vec![None, Some(2)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(5), Some(6)]), + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item", + true, + ); + array.try_extend(data).unwrap(); + let a: ListArray = array.into(); + let a = a.sliced(2, 1); + assert_array_roundtrip(false, a.boxed(), None) +} + +#[test] +fn struct_slice() -> PolarsResult<()> { + let a = pyarrow_nested_nullable("struct_list_nullable"); + + let a = a.sliced(2, 1); + assert_array_roundtrip(true, a, None) +} + +#[test] +fn list_struct_slice() -> PolarsResult<()> { + let a = pyarrow_nested_nullable("list_struct_nullable"); + + let a = a.sliced(2, 1); + assert_array_roundtrip(true, a, None) +} + +#[test] +fn list_int_nullable() -> PolarsResult<()> { + let data = vec![ + Some(vec![Some(1)]), + None, + Some(vec![None, Some(2)]), + Some(vec![]), + Some(vec![Some(3)]), + None, + ]; + let mut array = MutableListArray::::new_with_field( + MutablePrimitiveArray::::new(), + "item", + true, + ); + array.try_extend(data).unwrap(); + assert_array_roundtrip(true, array.into_box(), None) +} + +#[test] +fn limit() -> PolarsResult<()> { + let (schema, chunk) = generic_data()?; + assert_roundtrip(schema, chunk, Some(2)) +} + +#[test] +fn limit_list() -> PolarsResult<()> { + test_list_array_required_required(Some(2)) +} + +fn nested_dict_data( + data_type: ArrowDataType, +) -> PolarsResult<(ArrowSchema, Chunk>)> { + let values = match data_type { + ArrowDataType::Float32 => PrimitiveArray::from_slice([1.0f32, 3.0]).boxed(), + ArrowDataType::Utf8View => Utf8ViewArray::from_slice([Some("a"), Some("b")]).boxed(), + _ => unreachable!(), + }; + + let indices = PrimitiveArray::from_values((0..3u64).map(|x| x % 2)); + let values = DictionaryArray::try_from_keys(indices, values).unwrap(); + let values = LargeListArray::try_new( + ArrowDataType::LargeList(Box::new(Field::new( + "item", + values.data_type().clone(), + false, + ))), + vec![0i64, 0, 0, 2, 3].try_into().unwrap(), + values.boxed(), + Some([true, false, true, true].into()), + )?; + + let schema = ArrowSchema::from(vec![Field::new("c1", values.data_type().clone(), true)]); + let chunk = Chunk::try_new(vec![values.boxed()])?; + + Ok((schema, chunk)) +} + +#[test] +fn nested_dict() -> PolarsResult<()> { + let (schema, chunk) = nested_dict_data(ArrowDataType::Float32)?; + + assert_roundtrip(schema, chunk, None) +} + +#[test] +fn nested_dict_utf8() -> PolarsResult<()> { + let (schema, chunk) = nested_dict_data(ArrowDataType::Utf8View)?; + + assert_roundtrip(schema, chunk, None) +} + +#[test] +fn nested_dict_limit() -> PolarsResult<()> { + let (schema, chunk) = nested_dict_data(ArrowDataType::Float32)?; + + assert_roundtrip(schema, chunk, Some(2)) +} + +#[test] +fn filter_chunk() -> PolarsResult<()> { + let chunk1 = Chunk::new(vec![PrimitiveArray::from_slice([1i16, 3]).boxed()]); + let chunk2 = Chunk::new(vec![PrimitiveArray::from_slice([2i16, 4]).boxed()]); + let schema = ArrowSchema::from(vec![Field::new("c1", ArrowDataType::Int16, true)]); + + let r = integration_write(&schema, &[chunk1.clone(), chunk2.clone()])?; + + let mut reader = Cursor::new(r); + + let metadata = p_read::read_metadata(&mut reader)?; + + let new_schema = p_read::infer_schema(&metadata)?; + assert_eq!(new_schema, schema); + + // select chunk 1 + let row_groups = metadata + .row_groups + .into_iter() + .enumerate() + .filter(|(index, _)| *index == 0) + .map(|(_, row_group)| row_group) + .collect(); + + let reader = p_read::FileReader::new(reader, row_groups, schema, None, None, None); + + let new_chunks = reader.collect::>>()?; + + assert_eq!(new_chunks, vec![chunk1]); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/read.rs b/crates/polars/tests/it/io/parquet/arrow/read.rs new file mode 100644 index 0000000000000..1767c63da388a --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/read.rs @@ -0,0 +1,162 @@ +use std::path::PathBuf; + +use polars_parquet::arrow::read::*; + +use super::*; +#[cfg(feature = "parquet")] +#[test] +fn all_types() -> PolarsResult<()> { + let dir = env!("CARGO_MANIFEST_DIR"); + let path = PathBuf::from(dir).join("../../docs/data/alltypes_plain.parquet"); + + let mut reader = std::fs::File::open(path)?; + + let metadata = read_metadata(&mut reader)?; + let schema = infer_schema(&metadata)?; + let reader = FileReader::new(reader, metadata.row_groups, schema, None, None, None); + + let batches = reader.collect::>>()?; + assert_eq!(batches.len(), 1); + + let result = batches[0].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([4, 5, 6, 7, 2, 3, 0, 1])); + + let result = batches[0].columns()[6] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &Float32Array::from_slice([0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]) + ); + + let result = batches[0].columns()[9] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &BinaryViewArray::from_slice_values([[48], [49], [48], [49], [48], [49], [48], [49]]) + ); + + Ok(()) +} + +#[cfg(feature = "parquet")] +#[test] +fn all_types_chunked() -> PolarsResult<()> { + // this has one batch with 8 elements + let dir = env!("CARGO_MANIFEST_DIR"); + let path = PathBuf::from(dir).join("../../docs/data/alltypes_plain.parquet"); + let mut reader = std::fs::File::open(path)?; + + let metadata = read_metadata(&mut reader)?; + let schema = infer_schema(&metadata)?; + // chunk it in 5 (so, (5,3)) + let reader = FileReader::new(reader, metadata.row_groups, schema, Some(5), None, None); + + let batches = reader.collect::>>()?; + assert_eq!(batches.len(), 2); + + assert_eq!(batches[0].len(), 5); + assert_eq!(batches[1].len(), 3); + + let result = batches[0].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([4, 5, 6, 7, 2])); + + let result = batches[1].columns()[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Int32Array::from_slice([3, 0, 1])); + + let result = batches[0].columns()[6] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result, &Float32Array::from_slice([0.0, 1.1, 0.0, 1.1, 0.0])); + + let result = batches[0].columns()[9] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &BinaryViewArray::from_slice_values([[48], [49], [48], [49], [48]]) + ); + + let result = batches[1].columns()[9] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + result, + &BinaryViewArray::from_slice_values([[49], [48], [49]]) + ); + + Ok(()) +} + +#[test] +fn read_int96_timestamps() -> PolarsResult<()> { + use std::collections::BTreeMap; + + let timestamp_data = &[ + 0x50, 0x41, 0x52, 0x31, 0x15, 0x04, 0x15, 0x48, 0x15, 0x3c, 0x4c, 0x15, 0x06, 0x15, 0x00, + 0x12, 0x00, 0x00, 0x24, 0x00, 0x00, 0x0d, 0x01, 0x08, 0x9f, 0xd5, 0x1f, 0x0d, 0x0a, 0x44, + 0x00, 0x00, 0x59, 0x68, 0x25, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x14, + 0xfb, 0x2a, 0x00, 0x15, 0x00, 0x15, 0x14, 0x15, 0x18, 0x2c, 0x15, 0x06, 0x15, 0x10, 0x15, + 0x06, 0x15, 0x06, 0x1c, 0x00, 0x00, 0x00, 0x0a, 0x24, 0x02, 0x00, 0x00, 0x00, 0x06, 0x01, + 0x02, 0x03, 0x24, 0x00, 0x26, 0x9e, 0x01, 0x1c, 0x15, 0x06, 0x19, 0x35, 0x10, 0x00, 0x06, + 0x19, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x73, 0x15, 0x02, + 0x16, 0x06, 0x16, 0x9e, 0x01, 0x16, 0x96, 0x01, 0x26, 0x60, 0x26, 0x08, 0x29, 0x2c, 0x15, + 0x04, 0x15, 0x00, 0x15, 0x02, 0x00, 0x15, 0x00, 0x15, 0x10, 0x15, 0x02, 0x00, 0x00, 0x00, + 0x15, 0x04, 0x19, 0x2c, 0x35, 0x00, 0x18, 0x06, 0x73, 0x63, 0x68, 0x65, 0x6d, 0x61, 0x15, + 0x02, 0x00, 0x15, 0x06, 0x25, 0x02, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, + 0x6d, 0x70, 0x73, 0x00, 0x16, 0x06, 0x19, 0x1c, 0x19, 0x1c, 0x26, 0x9e, 0x01, 0x1c, 0x15, + 0x06, 0x19, 0x35, 0x10, 0x00, 0x06, 0x19, 0x18, 0x0a, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x73, 0x15, 0x02, 0x16, 0x06, 0x16, 0x9e, 0x01, 0x16, 0x96, 0x01, 0x26, + 0x60, 0x26, 0x08, 0x29, 0x2c, 0x15, 0x04, 0x15, 0x00, 0x15, 0x02, 0x00, 0x15, 0x00, 0x15, + 0x10, 0x15, 0x02, 0x00, 0x00, 0x00, 0x16, 0x9e, 0x01, 0x16, 0x06, 0x26, 0x08, 0x16, 0x96, + 0x01, 0x14, 0x00, 0x00, 0x28, 0x20, 0x70, 0x61, 0x72, 0x71, 0x75, 0x65, 0x74, 0x2d, 0x63, + 0x70, 0x70, 0x2d, 0x61, 0x72, 0x72, 0x6f, 0x77, 0x20, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, + 0x6e, 0x20, 0x31, 0x32, 0x2e, 0x30, 0x2e, 0x30, 0x19, 0x1c, 0x1c, 0x00, 0x00, 0x00, 0x95, + 0x00, 0x00, 0x00, 0x50, 0x41, 0x52, 0x31, + ]; + + let parse = |time_unit: TimeUnit| { + let mut reader = Cursor::new(timestamp_data); + let metadata = read_metadata(&mut reader)?; + let schema = arrow::datatypes::ArrowSchema { + fields: vec![arrow::datatypes::Field::new( + "timestamps", + arrow::datatypes::ArrowDataType::Timestamp(time_unit, None), + false, + )], + metadata: BTreeMap::new(), + }; + let reader = FileReader::new(reader, metadata.row_groups, schema, Some(5), None, None); + reader.collect::>>() + }; + + // This data contains int96 timestamps in the year 1000 and 3000, which are out of range for + // Timestamp(TimeUnit::Nanoseconds) and will cause a panic in dev builds/overflow in release builds + // However, the code should work for the Microsecond/Millisecond time units + for time_unit in [ + arrow::datatypes::TimeUnit::Microsecond, + arrow::datatypes::TimeUnit::Millisecond, + arrow::datatypes::TimeUnit::Second, + ] { + parse(time_unit).expect("Should not error"); + } + std::panic::catch_unwind(|| parse(arrow::datatypes::TimeUnit::Nanosecond)) + .expect_err("Should be a panic error"); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/read_indexes.rs b/crates/polars/tests/it/io/parquet/arrow/read_indexes.rs new file mode 100644 index 0000000000000..ec16f2c9a3631 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/read_indexes.rs @@ -0,0 +1,257 @@ +use std::io::Cursor; + +use arrow::array::*; +use arrow::chunk::Chunk; +use arrow::datatypes::*; +use polars_error::{PolarsError, PolarsResult}; +use polars_parquet::read::*; +use polars_parquet::write::*; + +/// Returns 2 sets of pages with different the same number of rows distributed un-evenly +fn pages( + arrays: &[&dyn Array], + encoding: Encoding, +) -> PolarsResult<(Vec, Vec, ArrowSchema)> { + // create pages with different number of rows + let array11 = PrimitiveArray::::from_slice([1, 2, 3, 4]); + let array12 = PrimitiveArray::::from_slice([5]); + let array13 = PrimitiveArray::::from_slice([6]); + + let schema = ArrowSchema::from(vec![ + Field::new("a1", ArrowDataType::Int64, false), + Field::new( + "a2", + arrays[0].data_type().clone(), + arrays.iter().map(|x| x.null_count()).sum::() != 0usize, + ), + ]); + + let parquet_schema = to_parquet_schema(&schema)?; + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_pagesize_limit: None, + }; + + let pages1 = [array11, array12, array13] + .into_iter() + .map(|array| { + array_to_page( + &array, + parquet_schema.columns()[0] + .descriptor + .primitive_type + .clone(), + &[Nested::Primitive(None, true, array.len())], + options, + Encoding::Plain, + ) + }) + .collect::>>()?; + + let pages2 = arrays + .iter() + .flat_map(|array| { + array_to_pages( + *array, + parquet_schema.columns()[1] + .descriptor + .primitive_type + .clone(), + &[Nested::Primitive(None, true, array.len())], + options, + encoding, + ) + .unwrap() + .collect::>>() + .unwrap() + }) + .collect::>(); + + Ok((pages1, pages2, schema)) +} + +/// Tests reading pages while skipping indexes +fn read_with_indexes( + (pages1, pages2, schema): (Vec, Vec, ArrowSchema), + expected: Box, +) -> PolarsResult<()> { + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V1, + data_pagesize_limit: None, + }; + + let to_compressed = |pages: Vec| { + let encoded_pages = DynIter::new(pages.into_iter().map(Ok)); + let compressed_pages = + Compressor::new(encoded_pages, options.compression, vec![]).map_err(PolarsError::from); + PolarsResult::Ok(DynStreamingIterator::new(compressed_pages)) + }; + + let row_group = DynIter::new(vec![to_compressed(pages1), to_compressed(pages2)].into_iter()); + + let writer = vec![]; + let mut writer = FileWriter::try_new(writer, schema, options)?; + + writer.write(row_group)?; + writer.end(None)?; + let data = writer.into_inner(); + + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + let schema = infer_schema(&metadata)?; + + // row group-based filtering can be done here + let row_groups = metadata.row_groups; + + // one per row group + let pages = row_groups + .iter() + .map(|row_group| { + assert!(indexes::has_indexes(row_group)); + + indexes::read_filtered_pages(&mut reader, row_group, &schema.fields, |_, intervals| { + let first_field = &intervals[0]; + let first_field_column = &first_field[0]; + assert_eq!(first_field_column.len(), 3); + let selection = [false, true, false]; + + first_field_column + .iter() + .zip(selection) + .filter(|(_i, is_selected)| *is_selected) + .map(|(i, _is_selected)| *i) + .collect() + }) + }) + .collect::>>()?; + + // apply projection pushdown + let schema = schema.filter(|index, _| index == 1); + let pages = pages + .into_iter() + .map(|pages| { + pages + .into_iter() + .enumerate() + .filter(|(index, _)| *index == 1) + .map(|(_, pages)| pages) + .collect::>() + }) + .collect::>(); + + let expected = Chunk::new(vec![expected]); + + let chunks = FileReader::new( + reader, + row_groups, + schema, + Some(1024 * 8 * 8), + None, + Some(pages), + ); + + let arrays = chunks.collect::>>()?; + + assert_eq!(arrays, vec![expected]); + Ok(()) +} + +#[test] +fn indexed_required_i64() -> PolarsResult<()> { + let array21 = Int32Array::from_slice([1, 2, 3]); + let array22 = Int32Array::from_slice([4, 5, 6]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_i64() -> PolarsResult<()> { + let array21 = Int32Array::from([Some(1), Some(2), None]); + let array22 = Int32Array::from([None, Some(5), Some(6)]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_i64_delta() -> PolarsResult<()> { + let array21 = Int32Array::from([Some(1), Some(2), None]); + let array22 = Int32Array::from([None, Some(5), Some(6)]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes( + pages(&[&array21, &array22], Encoding::DeltaBinaryPacked)?, + expected, + ) +} + +#[test] +fn indexed_required_i64_delta() -> PolarsResult<()> { + let array21 = Int32Array::from_slice([1, 2, 3]); + let array22 = Int32Array::from_slice([4, 5, 6]); + let expected = Int32Array::from_slice([5]).boxed(); + + read_with_indexes( + pages(&[&array21, &array22], Encoding::DeltaBinaryPacked)?, + expected, + ) +} + +#[test] +fn indexed_required_fixed_len() -> PolarsResult<()> { + let array21 = FixedSizeBinaryArray::from_slice([[127], [128], [129]]); + let array22 = FixedSizeBinaryArray::from_slice([[130], [131], [132]]); + let expected = FixedSizeBinaryArray::from_slice([[131]]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_fixed_len() -> PolarsResult<()> { + let array21 = FixedSizeBinaryArray::from([Some([127]), Some([128]), None]); + let array22 = FixedSizeBinaryArray::from([None, Some([131]), Some([132])]); + let expected = FixedSizeBinaryArray::from_slice([[131]]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_required_boolean() -> PolarsResult<()> { + let array21 = BooleanArray::from_slice([true, false, true]); + let array22 = BooleanArray::from_slice([false, false, true]); + let expected = BooleanArray::from_slice([false]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_optional_boolean() -> PolarsResult<()> { + let array21 = BooleanArray::from([Some(true), Some(false), None]); + let array22 = BooleanArray::from([None, Some(false), Some(true)]); + let expected = BooleanArray::from_slice([false]).boxed(); + + read_with_indexes(pages(&[&array21, &array22], Encoding::Plain)?, expected) +} + +#[test] +fn indexed_dict() -> PolarsResult<()> { + let indices = PrimitiveArray::from_values((0..6u64).map(|x| x % 2)); + let values = PrimitiveArray::from_slice([4i64, 6i64]).boxed(); + let array = DictionaryArray::try_from_keys(indices, values).unwrap(); + + let indices = PrimitiveArray::from_slice([0u64]); + let values = PrimitiveArray::from_slice([4i64, 6i64]).boxed(); + let expected = DictionaryArray::try_from_keys(indices, values).unwrap(); + + let expected = expected.boxed(); + + read_with_indexes(pages(&[&array], Encoding::RleDictionary)?, expected) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/sample_tests.rs b/crates/polars/tests/it/io/parquet/arrow/sample_tests.rs new file mode 100644 index 0000000000000..a577ee0efe7b1 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/sample_tests.rs @@ -0,0 +1,115 @@ +use std::borrow::Borrow; +use std::io::Cursor; + +use arrow2::chunk::Chunk; +use arrow2::datatypes::{Field, Metadata, Schema}; +use arrow2::error::Result; +use arrow2::io::parquet::read as p_read; +use arrow2::io::parquet::write::*; +use sample_arrow2::array::ArbitraryArray; +use sample_arrow2::chunk::{ArbitraryChunk, ChainedChunk}; +use sample_arrow2::datatypes::{sample_flat, ArbitraryArrowDataType}; +use sample_std::{Chance, Random, Regex, Sample}; +use sample_test::sample_test; + +fn deep_chunk(depth: usize, len: usize) -> ArbitraryChunk { + let names = Regex::new("[a-z]{4,8}"); + let data_type = ArbitraryArrowDataType { + struct_branch: 1..3, + names: names.clone(), + // TODO: this breaks the test + // nullable: Chance(0.5), + nullable: Chance(0.0), + flat: sample_flat, + } + .sample_depth(depth); + + let array = ArbitraryArray { + names, + branch: 0..10, + len: len..(len + 1), + null: Chance(0.1), + // TODO: this breaks the test + // is_nullable: true, + is_nullable: false, + }; + + ArbitraryChunk { + // TODO: shrinking appears to be an issue with chunks this large. issues + // currently reproduce on the smaller sizes anyway. + // chunk_len: 10..1000, + chunk_len: 1..10, + array_count: 1..2, + data_type, + array, + } +} + +#[sample_test] +fn round_trip_sample( + #[sample(deep_chunk(5, 100).sample_one())] chained: ChainedChunk, +) -> Result<()> { + sample_test::env_logger_init(); + let chunks = vec![chained.value]; + let name = Regex::new("[a-z]{4, 8}"); + let mut g = Random::new(); + + // TODO: this probably belongs in a helper in sample-arrow2 + let schema = Schema { + fields: chunks + .first() + .unwrap() + .iter() + .map(|arr| { + Field::new( + name.generate(&mut g), + arr.data_type().clone(), + arr.validity().is_some(), + ) + }) + .collect(), + metadata: Metadata::default(), + }; + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Uncompressed, + version: Version::V2, + data_pagesize_limit: None, + }; + + let encodings: Vec<_> = schema + .borrow() + .fields + .iter() + .map(|field| transverse(field.data_type(), |_| Encoding::Plain)) + .collect(); + + let row_groups = RowGroupIterator::try_new( + chunks.clone().into_iter().map(Ok), + &schema, + options, + encodings, + )?; + + let buffer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(buffer, schema, options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let mut buffer = writer.into_inner(); + + let metadata = p_read::read_metadata(&mut buffer)?; + let schema = p_read::infer_schema(&metadata)?; + + let mut reader = p_read::FileReader::new(buffer, metadata.row_groups, schema, None, None, None); + + let result: Vec<_> = reader.collect::>()?; + + assert_eq!(result, chunks); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/arrow/write.rs b/crates/polars/tests/it/io/parquet/arrow/write.rs new file mode 100644 index 0000000000000..2f1f35d9d4564 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/arrow/write.rs @@ -0,0 +1,491 @@ +use polars_parquet::arrow::write::*; + +use super::*; + +fn round_trip( + column: &str, + file: &str, + version: Version, + compression: CompressionOptions, + encodings: Vec, +) -> PolarsResult<()> { + round_trip_opt_stats(column, file, version, compression, encodings, true) +} + +fn round_trip_opt_stats( + column: &str, + file: &str, + version: Version, + compression: CompressionOptions, + encodings: Vec, + check_stats: bool, +) -> PolarsResult<()> { + let (array, statistics) = match file { + "nested" => ( + pyarrow_nested_nullable(column), + pyarrow_nested_nullable_statistics(column), + ), + "nullable" => ( + pyarrow_nullable(column), + pyarrow_nullable_statistics(column), + ), + "required" => ( + pyarrow_required(column), + pyarrow_required_statistics(column), + ), + "struct" => (pyarrow_struct(column), pyarrow_struct_statistics(column)), + "nested_edge" => ( + pyarrow_nested_edge(column), + pyarrow_nested_edge_statistics(column), + ), + _ => unreachable!(), + }; + + let field = Field::new("a1", array.data_type().clone(), true); + let schema = ArrowSchema::from(vec![field]); + + let options = WriteOptions { + write_statistics: true, + compression, + version, + data_pagesize_limit: None, + }; + + let iter = vec![Chunk::try_new(vec![array.clone()])]; + + let row_groups = + RowGroupIterator::try_new(iter.into_iter(), &schema, options, vec![encodings])?; + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::try_new(writer, schema, options)?; + + for group in row_groups { + writer.write(group?)?; + } + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + let (result, stats) = read_column(&mut Cursor::new(data), "a1")?; + + assert_eq!(array.as_ref(), result.as_ref()); + if check_stats { + assert_eq!(statistics, stats); + } + Ok(()) +} + +#[test] +fn int64_optional_v1() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_required_v1() -> PolarsResult<()> { + round_trip( + "int64", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_optional_v2() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn int64_optional_delta() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaBinaryPacked], + ) +} + +#[test] +fn int64_required_delta() -> PolarsResult<()> { + round_trip( + "int64", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaBinaryPacked], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn int64_optional_v2_compressed() -> PolarsResult<()> { + round_trip( + "int64", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v1() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_required_v1() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_optional_v2() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn utf8_required_v2() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn utf8_optional_v2_compressed() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn utf8_required_v2_compressed() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_optional_v1() -> PolarsResult<()> { + round_trip( + "bool", + "nullable", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_required_v1() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_optional_v2_uncompressed() -> PolarsResult<()> { + round_trip( + "bool", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn bool_required_v2_uncompressed() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn bool_required_v2_compressed() -> PolarsResult<()> { + round_trip( + "bool", + "required", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_optional_v2() -> PolarsResult<()> { + round_trip( + "list_int64", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_optional_v1() -> PolarsResult<()> { + round_trip( + "list_int64", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_required_required_v1() -> PolarsResult<()> { + round_trip( + "list_int64_required_required", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_int64_required_required_v2() -> PolarsResult<()> { + round_trip( + "list_int64_required_required", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_bool_optional_v2() -> PolarsResult<()> { + round_trip( + "list_bool", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_bool_optional_v1() -> PolarsResult<()> { + round_trip( + "list_bool", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_utf8_optional_v2() -> PolarsResult<()> { + round_trip( + "list_utf8", + "nested", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_utf8_optional_v1() -> PolarsResult<()> { + round_trip( + "list_utf8", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn list_nested_inner_required_required_i64() -> PolarsResult<()> { + round_trip_opt_stats( + "list_nested_inner_required_required_i64", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + false, + ) +} + +#[test] +fn v1_nested_struct_list_nullable() -> PolarsResult<()> { + round_trip_opt_stats( + "struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + true, + ) +} + +#[test] +fn v1_nested_list_struct_list_nullable() -> PolarsResult<()> { + round_trip_opt_stats( + "list_struct_list_nullable", + "nested", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + true, + ) +} + +#[test] +fn utf8_optional_v2_delta() -> PolarsResult<()> { + round_trip( + "string", + "nullable", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaLengthByteArray], + ) +} + +#[test] +fn utf8_required_v2_delta() -> PolarsResult<()> { + round_trip( + "string", + "required", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::DeltaLengthByteArray], + ) +} + +#[cfg(feature = "parquet")] +#[test] +fn i64_optional_v2_dict_compressed() -> PolarsResult<()> { + round_trip( + "int32_dict", + "nullable", + Version::V2, + CompressionOptions::Snappy, + vec![Encoding::RleDictionary], + ) +} + +#[test] +fn struct_v1() -> PolarsResult<()> { + round_trip( + "struct", + "struct", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain, Encoding::Plain], + ) +} + +#[test] +fn struct_v2() -> PolarsResult<()> { + round_trip( + "struct", + "struct", + Version::V2, + CompressionOptions::Uncompressed, + vec![Encoding::Plain, Encoding::Plain], + ) +} + +#[test] +fn nested_edge_simple() -> PolarsResult<()> { + round_trip( + "simple", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn nested_edge_null() -> PolarsResult<()> { + round_trip( + "null", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn v1_nested_edge_struct_list_nullable() -> PolarsResult<()> { + round_trip( + "struct_list_nullable", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} + +#[test] +fn nested_edge_list_struct_list_nullable() -> PolarsResult<()> { + round_trip( + "list_struct_list_nullable", + "nested_edge", + Version::V1, + CompressionOptions::Uncompressed, + vec![Encoding::Plain], + ) +} diff --git a/crates/polars/tests/it/io/parquet/mod.rs b/crates/polars/tests/it/io/parquet/mod.rs new file mode 100644 index 0000000000000..ba6bbe5dc7240 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/mod.rs @@ -0,0 +1,207 @@ +#![forbid(unsafe_code)] +mod arrow; +mod read; +mod roundtrip; +mod write; + +use std::io::Cursor; +use std::path::PathBuf; + +use polars::prelude::*; + +// The dynamic representation of values in native Rust. This is not exhaustive. +// todo: maybe refactor this into serde/json? +#[derive(Debug, PartialEq)] +pub enum Array { + Int32(Vec>), + Int64(Vec>), + Int96(Vec>), + Float(Vec>), + Double(Vec>), + Boolean(Vec>), + Binary(Vec>>), + FixedLenBinary(Vec>>), + List(Vec>), + Struct(Vec, Vec), +} + +use std::sync::Arc; + +use polars_parquet::parquet::schema::types::{PhysicalType, PrimitiveType}; +use polars_parquet::parquet::statistics::*; + +pub fn alltypes_plain(column: &str) -> Array { + match column { + "id" => { + let expected = vec![4, 5, 6, 7, 2, 3, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "id-short-array" => { + let expected = vec![4]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "bool_col" => { + let expected = vec![true, false, true, false, true, false, true, false]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Boolean(expected) + }, + "tinyint_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "smallint_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "int_col" => { + let expected = vec![0, 1, 0, 1, 0, 1, 0, 1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int32(expected) + }, + "bigint_col" => { + let expected = vec![0, 10, 0, 10, 0, 10, 0, 10]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Int64(expected) + }, + "float_col" => { + let expected = vec![0.0, 1.1, 0.0, 1.1, 0.0, 1.1, 0.0, 1.1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Float(expected) + }, + "double_col" => { + let expected = vec![0.0, 10.1, 0.0, 10.1, 0.0, 10.1, 0.0, 10.1]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Double(expected) + }, + "date_string_col" => { + let expected = vec![ + vec![48, 51, 47, 48, 49, 47, 48, 57], + vec![48, 51, 47, 48, 49, 47, 48, 57], + vec![48, 52, 47, 48, 49, 47, 48, 57], + vec![48, 52, 47, 48, 49, 47, 48, 57], + vec![48, 50, 47, 48, 49, 47, 48, 57], + vec![48, 50, 47, 48, 49, 47, 48, 57], + vec![48, 49, 47, 48, 49, 47, 48, 57], + vec![48, 49, 47, 48, 49, 47, 48, 57], + ]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Binary(expected) + }, + "string_col" => { + let expected = vec![ + vec![48], + vec![49], + vec![48], + vec![49], + vec![48], + vec![49], + vec![48], + vec![49], + ]; + let expected = expected.into_iter().map(Some).collect::>(); + Array::Binary(expected) + }, + "timestamp_col" => { + todo!() + }, + _ => unreachable!(), + } +} + +pub fn alltypes_statistics(column: &str) -> Arc { + match column { + "id" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(7), + }), + "id-short-array" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(4), + max_value: Some(4), + }), + "bool_col" => Arc::new(BooleanStatistics { + null_count: Some(0), + distinct_count: None, + min_value: Some(false), + max_value: Some(true), + }), + "tinyint_col" | "smallint_col" | "int_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int32), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(1), + }), + "bigint_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Int64), + null_count: Some(0), + distinct_count: None, + min_value: Some(0), + max_value: Some(10), + }), + "float_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Float), + null_count: Some(0), + distinct_count: None, + min_value: Some(0.0), + max_value: Some(1.1), + }), + "double_col" => Arc::new(PrimitiveStatistics:: { + primitive_type: PrimitiveType::from_physical("col".to_string(), PhysicalType::Double), + null_count: Some(0), + distinct_count: None, + min_value: Some(0.0), + max_value: Some(10.1), + }), + "date_string_col" => Arc::new(BinaryStatistics { + primitive_type: PrimitiveType::from_physical( + "col".to_string(), + PhysicalType::ByteArray, + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![48, 49, 47, 48, 49, 47, 48, 57]), + max_value: Some(vec![48, 52, 47, 48, 49, 47, 48, 57]), + }), + "string_col" => Arc::new(BinaryStatistics { + primitive_type: PrimitiveType::from_physical( + "col".to_string(), + PhysicalType::ByteArray, + ), + null_count: Some(0), + distinct_count: None, + min_value: Some(vec![48]), + max_value: Some(vec![49]), + }), + "timestamp_col" => { + todo!() + }, + _ => unreachable!(), + } +} + +#[test] +fn test_vstack_empty_3220() -> PolarsResult<()> { + let df1 = df! { + "a" => ["1", "2"], + "b" => [1, 2] + }?; + let empty_df = df1.head(Some(0)); + let mut stacked = df1.clone(); + stacked.vstack_mut(&empty_df)?; + stacked.vstack_mut(&df1)?; + let mut buf = Cursor::new(Vec::new()); + ParquetWriter::new(&mut buf).finish(&mut stacked)?; + let read_df = ParquetReader::new(buf).finish()?; + assert!(stacked.equals(&read_df)); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/binary.rs b/crates/polars/tests/it/io/parquet/read/binary.rs new file mode 100644 index 0000000000000..5311c8066e148 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/binary.rs @@ -0,0 +1,33 @@ +use polars_parquet::parquet::deserialize::FixedLenBinaryPageState; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::page::DataPage; + +use super::dictionary::BinaryPageDict; +use super::utils::deserialize_optional; + +pub fn page_to_vec(page: &DataPage, dict: Option<&BinaryPageDict>) -> Result>>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = FixedLenBinaryPageState::try_new(page, dict)?; + + match state { + FixedLenBinaryPageState::Optional(validity, values) => { + deserialize_optional(validity, values.map(|x| Ok(x.to_vec()))) + }, + FixedLenBinaryPageState::Required(values) => values + .map(|x| Ok(x.to_vec())) + .map(Some) + .map(|x| x.transpose()) + .collect(), + FixedLenBinaryPageState::RequiredDictionary(dict) => dict + .indexes + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec()).map(Some)) + .collect(), + FixedLenBinaryPageState::OptionalDictionary(validity, dict) => { + let values = dict + .indexes + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec())); + deserialize_optional(validity, values) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/boolean.rs b/crates/polars/tests/it/io/parquet/read/boolean.rs new file mode 100644 index 0000000000000..7642f4023fff2 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/boolean.rs @@ -0,0 +1,20 @@ +use polars_parquet::parquet::deserialize::BooleanPageState; +use polars_parquet::parquet::encoding::hybrid_rle::BitmapIter; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::page::DataPage; + +use super::utils::deserialize_optional; + +pub fn page_to_vec(page: &DataPage) -> Result>> { + assert_eq!(page.descriptor.max_rep_level, 0); + let state = BooleanPageState::try_new(page)?; + + match state { + BooleanPageState::Optional(validity, mut values) => { + deserialize_optional(validity, values.by_ref().map(Ok)) + }, + BooleanPageState::Required(bitmap, length) => { + Ok(BitmapIter::new(bitmap, 0, length).map(Some).collect()) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/deserialize.rs b/crates/polars/tests/it/io/parquet/read/deserialize.rs new file mode 100644 index 0000000000000..1b5cf18b14528 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/deserialize.rs @@ -0,0 +1,314 @@ +use polars_parquet::parquet::deserialize::{ + FilteredHybridBitmapIter, FilteredHybridEncoded, HybridEncoded, +}; +use polars_parquet::parquet::indexes::Interval; + +#[test] +fn bitmap_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Bitmap(&[0b01000011], 7))].into_iter(), + vec![Interval::new(1, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 1, + length: 2, + } + ] + ); +} + +#[test] +fn bitmap_complete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Bitmap(&[0b01000011], 8))].into_iter(), + vec![Interval::new(0, 8)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 0, + length: 8, + }] + ); +} + +#[test] +fn bitmap_interval_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01000011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(0, 10)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01000011], + offset: 0, + length: 8, + }, + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 0, + length: 2, + } + ] + ); +} + +#[test] +fn bitmap_interval_run_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01100011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(0, 5), Interval::new(7, 4)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 0, + length: 5, + }, + FilteredHybridEncoded::Skipped(2), + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 7, + length: 1, + }, + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 0, + length: 3, + } + ] + ); +} + +#[test] +fn bitmap_interval_run_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01100011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(4), + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 1, + length: 2, + }, + ] + ); +} + +#[test] +fn bitmap_interval_run_offset_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Bitmap(&[0b01100011], 8)), + Ok(HybridEncoded::Bitmap(&[0b11111111], 8)), + ] + .into_iter(), + vec![Interval::new(0, 1), Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Bitmap { + values: &[0b01100011], + offset: 0, + length: 1, + }, + FilteredHybridEncoded::Skipped(3), + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Bitmap { + values: &[0b11111111], + offset: 1, + length: 2, + }, + ] + ); +} + +#[test] +fn repeated_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Repeated(true, 7))].into_iter(), + vec![Interval::new(1, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(1), + FilteredHybridEncoded::Repeated { + is_set: true, + length: 2, + } + ] + ); +} + +#[test] +fn repeated_complete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![Ok(HybridEncoded::Repeated(true, 8))].into_iter(), + vec![Interval::new(0, 8)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![FilteredHybridEncoded::Repeated { + is_set: true, + length: 8, + }] + ); +} + +#[test] +fn repeated_interval_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(0, 10)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 8, + }, + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + } + ] + ); +} + +#[test] +fn repeated_interval_run_incomplete() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(0, 5), Interval::new(7, 4)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 5, + }, + FilteredHybridEncoded::Skipped(2), + FilteredHybridEncoded::Repeated { + is_set: true, + length: 1, + }, + FilteredHybridEncoded::Repeated { + is_set: false, + length: 3, + } + ] + ); +} + +#[test] +fn repeated_interval_run_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Skipped(8), + FilteredHybridEncoded::Skipped(0), + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + }, + ] + ); +} + +#[test] +fn repeated_interval_run_offset_skipped() { + let mut iter = FilteredHybridBitmapIter::new( + vec![ + Ok(HybridEncoded::Repeated(true, 8)), + Ok(HybridEncoded::Repeated(false, 8)), + ] + .into_iter(), + vec![Interval::new(0, 1), Interval::new(9, 2)].into(), + ); + let a = iter.by_ref().collect::, _>>().unwrap(); + assert_eq!(iter.len(), 0); + assert_eq!( + a, + vec![ + FilteredHybridEncoded::Repeated { + is_set: true, + length: 1, + }, + FilteredHybridEncoded::Skipped(7), + FilteredHybridEncoded::Skipped(0), + FilteredHybridEncoded::Repeated { + is_set: false, + length: 2, + }, + ] + ); +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs b/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs new file mode 100644 index 0000000000000..8b9bce7c50e7a --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/binary.rs @@ -0,0 +1,48 @@ +use polars_parquet::parquet::encoding::get_length; +use polars_parquet::parquet::error::Error; + +#[derive(Debug)] +pub struct BinaryPageDict { + values: Vec>, +} + +impl BinaryPageDict { + pub fn new(values: Vec>) -> Self { + Self { values } + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&[u8], Error> { + self.values + .get(index) + .map(|x| x.as_ref()) + .ok_or_else(|| Error::OutOfSpec("invalid index".to_string())) + } +} + +fn read_plain(bytes: &[u8], length: usize) -> Result>, Error> { + let mut bytes = bytes; + let mut values = Vec::new(); + + for _ in 0..length { + let slot_length = get_length(bytes).unwrap(); + bytes = &bytes[4..]; + + if slot_length > bytes.len() { + return Err(Error::OutOfSpec( + "The string on a dictionary page has a length that is out of bounds".to_string(), + )); + } + let (result, remaining) = bytes.split_at(slot_length); + + values.push(result.to_vec()); + bytes = remaining; + } + + Ok(values) +} + +pub fn read(buf: &[u8], num_values: usize) -> Result { + let values = read_plain(buf, num_values)?; + Ok(BinaryPageDict::new(values)) +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs b/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs new file mode 100644 index 0000000000000..31b150fcb820b --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/fixed_len_binary.rs @@ -0,0 +1,31 @@ +use polars_parquet::parquet::error::{Error, Result}; + +#[derive(Debug)] +pub struct FixedLenByteArrayPageDict { + values: Vec, + size: usize, +} + +impl FixedLenByteArrayPageDict { + pub fn new(values: Vec, size: usize) -> Self { + Self { values, size } + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&[u8]> { + self.values + .get(index * self.size..(index + 1) * self.size) + .ok_or_else(|| { + Error::OutOfSpec( + "The data page has an index larger than the dictionary page values".to_string(), + ) + }) + } +} + +pub fn read(buf: &[u8], size: usize, num_values: usize) -> Result { + let length = size.saturating_mul(num_values); + let values = buf.get(..length).ok_or_else(|| Error::OutOfSpec("Fixed sized binary declares a number of values times size larger than the page buffer".to_string()))?.to_vec(); + + Ok(FixedLenByteArrayPageDict::new(values, size)) +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs b/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs new file mode 100644 index 0000000000000..4dcb2afbaf707 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/mod.rs @@ -0,0 +1,56 @@ +mod binary; +mod fixed_len_binary; +mod primitive; + +pub use binary::BinaryPageDict; +pub use fixed_len_binary::FixedLenByteArrayPageDict; +use polars_parquet::parquet::error::{Error, Result}; +use polars_parquet::parquet::page::DictPage; +use polars_parquet::parquet::schema::types::PhysicalType; +pub use primitive::PrimitivePageDict; + +pub enum DecodedDictPage { + Int32(PrimitivePageDict), + Int64(PrimitivePageDict), + Int96(PrimitivePageDict<[u32; 3]>), + Float(PrimitivePageDict), + Double(PrimitivePageDict), + ByteArray(BinaryPageDict), + FixedLenByteArray(FixedLenByteArrayPageDict), +} + +pub fn deserialize(page: &DictPage, physical_type: PhysicalType) -> Result { + _deserialize(&page.buffer, page.num_values, page.is_sorted, physical_type) +} + +fn _deserialize( + buf: &[u8], + num_values: usize, + is_sorted: bool, + physical_type: PhysicalType, +) -> Result { + match physical_type { + PhysicalType::Boolean => Err(Error::OutOfSpec( + "Boolean physical type cannot be dictionary-encoded".to_string(), + )), + PhysicalType::Int32 => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Int32) + }, + PhysicalType::Int64 => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Int64) + }, + PhysicalType::Int96 => { + primitive::read::<[u32; 3]>(buf, num_values, is_sorted).map(DecodedDictPage::Int96) + }, + PhysicalType::Float => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Float) + }, + PhysicalType::Double => { + primitive::read::(buf, num_values, is_sorted).map(DecodedDictPage::Double) + }, + PhysicalType::ByteArray => binary::read(buf, num_values).map(DecodedDictPage::ByteArray), + PhysicalType::FixedLenByteArray(size) => { + fixed_len_binary::read(buf, size, num_values).map(DecodedDictPage::FixedLenByteArray) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs new file mode 100644 index 0000000000000..aeeccf10eb5b4 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/dictionary/primitive.rs @@ -0,0 +1,47 @@ +use polars_parquet::parquet::error::{Error, Result}; +use polars_parquet::parquet::types::{decode, NativeType}; + +#[derive(Debug)] +pub struct PrimitivePageDict { + values: Vec, +} + +impl PrimitivePageDict { + pub fn new(values: Vec) -> Self { + Self { values } + } + + pub fn values(&self) -> &[T] { + &self.values + } + + #[inline] + pub fn value(&self, index: usize) -> Result<&T> { + self.values.get(index).ok_or_else(|| { + Error::OutOfSpec( + "The data page has an index larger than the dictionary page values".to_string(), + ) + }) + } +} + +pub fn read( + buf: &[u8], + num_values: usize, + _is_sorted: bool, +) -> Result> { + let size_of = std::mem::size_of::(); + + let typed_size = num_values.wrapping_mul(size_of); + + let values = buf.get(..typed_size).ok_or_else(|| { + Error::OutOfSpec( + "The number of values declared in the dict page does not match the length of the page" + .to_string(), + ) + })?; + + let values = values.chunks_exact(size_of).map(decode::).collect(); + + Ok(PrimitivePageDict::new(values)) +} diff --git a/crates/polars/tests/it/io/parquet/read/fixed_binary.rs b/crates/polars/tests/it/io/parquet/read/fixed_binary.rs new file mode 100644 index 0000000000000..b5e4750a1406e --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/fixed_binary.rs @@ -0,0 +1,34 @@ +use polars_parquet::parquet::deserialize::FixedLenBinaryPageState; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::page::DataPage; + +use super::dictionary::FixedLenByteArrayPageDict; +use super::utils::deserialize_optional; + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&FixedLenByteArrayPageDict>, +) -> Result>>> { + assert_eq!(page.descriptor.max_rep_level, 0); + + let state = FixedLenBinaryPageState::try_new(page, dict)?; + + match state { + FixedLenBinaryPageState::Optional(validity, values) => { + deserialize_optional(validity, values.map(|x| Ok(x.to_vec()))) + }, + FixedLenBinaryPageState::Required(values) => { + Ok(values.map(|x| x.to_vec()).map(Some).collect()) + }, + FixedLenBinaryPageState::RequiredDictionary(dict) => dict + .indexes + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec()).map(Some)) + .collect(), + FixedLenBinaryPageState::OptionalDictionary(validity, dict) => { + let values = dict + .indexes + .map(|x| dict.dict.value(x as usize).map(|x| x.to_vec())); + deserialize_optional(validity, values) + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/indexes.rs b/crates/polars/tests/it/io/parquet/read/indexes.rs new file mode 100644 index 0000000000000..ad79c6d04544f --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/indexes.rs @@ -0,0 +1,143 @@ +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::indexes::{ + BooleanIndex, BoundaryOrder, ByteIndex, Index, NativeIndex, PageIndex, PageLocation, +}; +use polars_parquet::parquet::read::{read_columns_indexes, read_metadata, read_pages_locations}; +use polars_parquet::parquet::schema::types::{ + FieldInfo, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, PrimitiveType, +}; +use polars_parquet::parquet::schema::Repetition; + +/* +import pyspark.sql # 3.2.1 +spark = pyspark.sql.SparkSession.builder.getOrCreate() +spark.conf.set("parquet.bloom.filter.enabled", True) +spark.conf.set("parquet.bloom.filter.expected.ndv", 10) +spark.conf.set("parquet.bloom.filter.max.bytes", 32) + +data = [(i, f"{i}", False) for i in range(10)] +df = spark.createDataFrame(data, ["id", "string", "bool"]).repartition(1) + +df.write.parquet("bla.parquet", mode = "overwrite") +*/ +const FILE: &[u8] = &[ + 80, 65, 82, 49, 21, 0, 21, 172, 1, 21, 138, 1, 21, 169, 161, 209, 137, 5, 28, 21, 20, 21, 0, + 21, 6, 21, 8, 0, 0, 86, 24, 2, 0, 0, 0, 20, 1, 0, 13, 1, 17, 9, 1, 22, 1, 1, 0, 3, 1, 5, 12, 0, + 0, 0, 4, 1, 5, 12, 0, 0, 0, 5, 1, 5, 12, 0, 0, 0, 6, 1, 5, 12, 0, 0, 0, 7, 1, 5, 72, 0, 0, 0, + 8, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 21, 0, 21, 112, 21, 104, 21, 138, 239, 232, + 170, 15, 28, 21, 20, 21, 0, 21, 6, 21, 8, 0, 0, 56, 40, 2, 0, 0, 0, 20, 1, 1, 0, 0, 0, 48, 1, + 5, 0, 49, 1, 5, 0, 50, 1, 5, 0, 51, 1, 5, 0, 52, 1, 5, 0, 53, 1, 5, 60, 54, 1, 0, 0, 0, 55, 1, + 0, 0, 0, 56, 1, 0, 0, 0, 57, 21, 0, 21, 16, 21, 20, 21, 202, 209, 169, 227, 4, 28, 21, 20, 21, + 0, 21, 6, 21, 8, 0, 0, 8, 28, 2, 0, 0, 0, 20, 1, 0, 0, 25, 17, 2, 25, 24, 8, 0, 0, 0, 0, 0, 0, + 0, 0, 25, 24, 8, 9, 0, 0, 0, 0, 0, 0, 0, 21, 2, 25, 22, 0, 0, 25, 17, 2, 25, 24, 1, 48, 25, 24, + 1, 57, 21, 2, 25, 22, 0, 0, 25, 17, 2, 25, 24, 1, 0, 25, 24, 1, 0, 21, 2, 25, 22, 0, 0, 25, 28, + 22, 8, 21, 188, 1, 22, 0, 0, 0, 25, 28, 22, 196, 1, 21, 150, 1, 22, 0, 0, 0, 25, 28, 22, 218, + 2, 21, 66, 22, 0, 0, 0, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 24, 130, 24, 8, + 134, 8, 68, 6, 2, 101, 128, 10, 64, 2, 38, 78, 114, 1, 64, 38, 1, 192, 194, 152, 64, 70, 0, 36, + 56, 121, 64, 0, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 8, 17, 10, 29, 5, 88, 194, + 0, 35, 208, 25, 16, 70, 68, 48, 38, 17, 16, 140, 68, 98, 56, 0, 131, 4, 193, 40, 129, 161, 160, + 1, 96, 21, 64, 28, 28, 0, 0, 28, 28, 0, 0, 28, 28, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 2, 25, 76, 72, 12, 115, 112, + 97, 114, 107, 95, 115, 99, 104, 101, 109, 97, 21, 6, 0, 21, 4, 37, 2, 24, 2, 105, 100, 0, 21, + 12, 37, 2, 24, 6, 115, 116, 114, 105, 110, 103, 37, 0, 76, 28, 0, 0, 0, 21, 0, 37, 2, 24, 4, + 98, 111, 111, 108, 0, 22, 20, 25, 28, 25, 60, 38, 8, 28, 21, 4, 25, 53, 0, 6, 8, 25, 24, 2, + 105, 100, 21, 2, 22, 20, 22, 222, 1, 22, 188, 1, 38, 8, 60, 24, 8, 9, 0, 0, 0, 0, 0, 0, 0, 24, + 8, 0, 0, 0, 0, 0, 0, 0, 0, 22, 0, 40, 8, 9, 0, 0, 0, 0, 0, 0, 0, 24, 8, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 25, 28, 21, 0, 21, 0, 21, 2, 0, 22, 226, 4, 0, 22, 158, 4, 21, 22, 22, 156, 3, 21, 62, 0, + 38, 196, 1, 28, 21, 12, 25, 53, 0, 6, 8, 25, 24, 6, 115, 116, 114, 105, 110, 103, 21, 2, 22, + 20, 22, 158, 1, 22, 150, 1, 38, 196, 1, 60, 54, 0, 40, 1, 57, 24, 1, 48, 0, 25, 28, 21, 0, 21, + 0, 21, 2, 0, 22, 192, 5, 0, 22, 180, 4, 21, 24, 22, 218, 3, 21, 34, 0, 38, 218, 2, 28, 21, 0, + 25, 53, 0, 6, 8, 25, 24, 4, 98, 111, 111, 108, 21, 2, 22, 20, 22, 62, 22, 66, 38, 218, 2, 60, + 24, 1, 0, 24, 1, 0, 22, 0, 40, 1, 0, 24, 1, 0, 0, 25, 28, 21, 0, 21, 0, 21, 2, 0, 22, 158, 6, + 0, 22, 204, 4, 21, 22, 22, 252, 3, 21, 34, 0, 22, 186, 3, 22, 20, 38, 8, 22, 148, 3, 20, 0, 0, + 25, 44, 24, 24, 111, 114, 103, 46, 97, 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, + 118, 101, 114, 115, 105, 111, 110, 24, 5, 51, 46, 50, 46, 49, 0, 24, 41, 111, 114, 103, 46, 97, + 112, 97, 99, 104, 101, 46, 115, 112, 97, 114, 107, 46, 115, 113, 108, 46, 112, 97, 114, 113, + 117, 101, 116, 46, 114, 111, 119, 46, 109, 101, 116, 97, 100, 97, 116, 97, 24, 213, 1, 123, 34, + 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 117, 99, 116, 34, 44, 34, 102, 105, 101, 108, + 100, 115, 34, 58, 91, 123, 34, 110, 97, 109, 101, 34, 58, 34, 105, 100, 34, 44, 34, 116, 121, + 112, 101, 34, 58, 34, 108, 111, 110, 103, 34, 44, 34, 110, 117, 108, 108, 97, 98, 108, 101, 34, + 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, 125, 125, 44, + 123, 34, 110, 97, 109, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 44, 34, 116, 121, + 112, 101, 34, 58, 34, 115, 116, 114, 105, 110, 103, 34, 44, 34, 110, 117, 108, 108, 97, 98, + 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, + 125, 125, 44, 123, 34, 110, 97, 109, 101, 34, 58, 34, 98, 111, 111, 108, 34, 44, 34, 116, 121, + 112, 101, 34, 58, 34, 98, 111, 111, 108, 101, 97, 110, 34, 44, 34, 110, 117, 108, 108, 97, 98, + 108, 101, 34, 58, 116, 114, 117, 101, 44, 34, 109, 101, 116, 97, 100, 97, 116, 97, 34, 58, 123, + 125, 125, 93, 125, 0, 24, 74, 112, 97, 114, 113, 117, 101, 116, 45, 109, 114, 32, 118, 101, + 114, 115, 105, 111, 110, 32, 49, 46, 49, 50, 46, 50, 32, 40, 98, 117, 105, 108, 100, 32, 55, + 55, 101, 51, 48, 99, 56, 48, 57, 51, 51, 56, 54, 101, 99, 53, 50, 99, 51, 99, 102, 97, 54, 99, + 51, 52, 98, 55, 101, 102, 51, 51, 50, 49, 51, 50, 50, 99, 57, 52, 41, 25, 60, 28, 0, 0, 28, 0, + 0, 28, 0, 0, 0, 182, 2, 0, 0, 80, 65, 82, 49, +]; + +#[test] +fn test() -> Result<(), Error> { + let mut reader = std::io::Cursor::new(FILE); + + let expected_index = vec![ + Box::new(NativeIndex:: { + primitive_type: PrimitiveType::from_physical("id".to_string(), PhysicalType::Int64), + indexes: vec![PageIndex { + min: Some(0), + max: Some(9), + null_count: Some(0), + }], + boundary_order: BoundaryOrder::Ascending, + }) as Box, + Box::new(ByteIndex { + primitive_type: PrimitiveType { + field_info: FieldInfo { + name: "string".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: Some(PrimitiveLogicalType::String), + converted_type: Some(PrimitiveConvertedType::Utf8), + physical_type: PhysicalType::ByteArray, + }, + indexes: vec![PageIndex { + min: Some(b"0".to_vec()), + max: Some(b"9".to_vec()), + null_count: Some(0), + }], + boundary_order: BoundaryOrder::Ascending, + }), + Box::new(BooleanIndex { + indexes: vec![PageIndex { + min: Some(false), + max: Some(false), + null_count: Some(0), + }], + boundary_order: BoundaryOrder::Ascending, + }), + ]; + let expected_page_locations = vec![ + vec![PageLocation { + offset: 4, + compressed_page_size: 94, + first_row_index: 0, + }], + vec![PageLocation { + offset: 98, + compressed_page_size: 75, + first_row_index: 0, + }], + vec![PageLocation { + offset: 173, + compressed_page_size: 33, + first_row_index: 0, + }], + ]; + + let metadata = read_metadata(&mut reader)?; + let columns = &metadata.row_groups[0].columns(); + + let indexes = read_columns_indexes(&mut reader, columns)?; + assert_eq!(&indexes, &expected_index); + + let pages = read_pages_locations(&mut reader, columns)?; + assert_eq!(pages, expected_page_locations); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/mod.rs b/crates/polars/tests/it/io/parquet/read/mod.rs new file mode 100644 index 0000000000000..49ec00fa2e0ff --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/mod.rs @@ -0,0 +1,432 @@ +/// Serialization to Rust's Native types. +/// In comparison to Arrow, this in-memory format does not leverage logical types nor SIMD operations, +/// but OTOH it has no external dependencies and is very familiar to Rust developers. +mod binary; +mod boolean; +mod deserialize; +mod dictionary; +mod fixed_binary; +mod indexes; +mod primitive; +mod primitive_nested; +mod struct_; +mod utils; + +use std::fs::File; + +use dictionary::{deserialize as deserialize_dict, DecodedDictPage}; +#[cfg(feature = "async")] +use futures::StreamExt; +use polars_parquet::parquet::error::{Error, Result}; +use polars_parquet::parquet::metadata::ColumnChunkMetaData; +use polars_parquet::parquet::page::{CompressedPage, DataPage, Page}; +#[cfg(feature = "async")] +use polars_parquet::parquet::read::get_page_stream; +#[cfg(feature = "async")] +use polars_parquet::parquet::read::read_metadata_async; +use polars_parquet::parquet::read::{ + get_column_iterator, get_field_columns, read_metadata, BasicDecompressor, MutStreamingIterator, + State, +}; +use polars_parquet::parquet::schema::types::{GroupConvertedType, ParquetType}; +use polars_parquet::parquet::schema::Repetition; +use polars_parquet::parquet::types::int96_to_i64_ns; +use polars_parquet::parquet::FallibleStreamingIterator; + +use super::*; + +pub fn get_path() -> PathBuf { + let dir = env!("CARGO_MANIFEST_DIR"); + PathBuf::from(dir).join("../../docs/data") +} + +/// Reads a page into an [`Array`]. +/// This is CPU-intensive: decompress, decode and de-serialize. +pub fn page_to_array(page: &DataPage, dict: Option<&DecodedDictPage>) -> Result { + let physical_type = page.descriptor.primitive_type.physical_type; + match page.descriptor.max_rep_level { + 0 => match physical_type { + PhysicalType::Boolean => Ok(Array::Boolean(boolean::page_to_vec(page)?)), + PhysicalType::Int32 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int32(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int32) + }, + PhysicalType::Int64 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int64(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int64) + }, + PhysicalType::Int96 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int96(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Int96) + }, + PhysicalType::Float => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Float(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Float) + }, + PhysicalType::Double => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Double(dict) = dict { + dict + } else { + panic!() + } + }); + primitive::page_to_vec(page, dict).map(Array::Double) + }, + PhysicalType::ByteArray => { + let dict = dict.map(|dict| { + if let DecodedDictPage::ByteArray(dict) = dict { + dict + } else { + panic!() + } + }); + + binary::page_to_vec(page, dict).map(Array::Binary) + }, + PhysicalType::FixedLenByteArray(_) => { + let dict = dict.map(|dict| { + if let DecodedDictPage::FixedLenByteArray(dict) = dict { + dict + } else { + panic!() + } + }); + + fixed_binary::page_to_vec(page, dict).map(Array::FixedLenBinary) + }, + }, + _ => match dict { + None => match physical_type { + PhysicalType::Int64 => Ok(primitive_nested::page_to_array::(page, None)?), + _ => todo!(), + }, + Some(_) => match physical_type { + PhysicalType::Int64 => { + let dict = dict.map(|dict| { + if let DecodedDictPage::Int64(dict) = dict { + dict + } else { + panic!() + } + }); + Ok(primitive_nested::page_dict_to_array(page, dict)?) + }, + _ => todo!(), + }, + }, + } +} + +pub fn collect>( + mut iterator: I, + type_: PhysicalType, +) -> Result> { + let mut arrays = vec![]; + let mut dict = None; + while let Some(page) = iterator.next()? { + match page { + Page::Data(page) => arrays.push(page_to_array(page, dict.as_ref())?), + Page::Dict(page) => { + dict = Some(deserialize_dict(page, type_)?); + }, + } + } + Ok(arrays) +} + +/// Reads columns into an [`Array`]. +/// This is CPU-intensive: decompress, decode and de-serialize. +pub fn columns_to_array(mut columns: I, field: &ParquetType) -> Result +where + II: Iterator>, + I: MutStreamingIterator, +{ + let mut validity = vec![]; + let mut has_filled = false; + let mut arrays = vec![]; + while let State::Some(mut new_iter) = columns.advance()? { + if let Some((pages, column)) = new_iter.get() { + let mut iterator = BasicDecompressor::new(pages, vec![]); + + let mut dict = None; + while let Some(page) = iterator.next()? { + match page { + polars_parquet::parquet::page::Page::Data(page) => { + if !has_filled { + struct_::extend_validity(&mut validity, page)?; + } + arrays.push(page_to_array(page, dict.as_ref())?) + }, + polars_parquet::parquet::page::Page::Dict(page) => { + dict = Some(deserialize_dict(page, column.physical_type())?); + }, + } + } + } + has_filled = true; + columns = new_iter; + } + + match field { + ParquetType::PrimitiveType { .. } => { + arrays.pop().ok_or_else(|| Error::OutOfSpec("".to_string())) + }, + ParquetType::GroupType { converted_type, .. } => { + if let Some(converted_type) = converted_type { + match converted_type { + GroupConvertedType::List => Ok(arrays.pop().unwrap()), + _ => todo!(), + } + } else { + Ok(Array::Struct(arrays, validity)) + } + }, + } +} + +pub fn read_column( + reader: &mut R, + row_group: usize, + field_name: &str, +) -> Result<(Array, Option>)> { + let metadata = read_metadata(reader)?; + + let field = metadata + .schema() + .fields() + .iter() + .find(|field| field.name() == field_name) + .ok_or_else(|| Error::OutOfSpec("column does not exist".to_string()))?; + + let columns = get_column_iterator( + reader, + &metadata.row_groups[row_group], + field.name(), + None, + vec![], + usize::MAX, + ); + + let mut statistics = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + .map(|column_meta| column_meta.statistics().transpose()) + .collect::>>()?; + + let array = columns_to_array(columns, field)?; + + Ok((array, statistics.pop().unwrap())) +} + +#[cfg(feature = "async")] +pub async fn read_column_async< + R: futures::AsyncRead + futures::AsyncSeek + Send + std::marker::Unpin, +>( + reader: &mut R, + row_group: usize, + field_name: &str, +) -> Result<(Array, Option>)> { + let metadata = read_metadata_async(reader).await?; + + let field = metadata + .schema() + .fields() + .iter() + .find(|field| field.name() == field_name) + .ok_or_else(|| Error::OutOfSpec("column does not exist".to_string()))?; + + let column = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + .next() + .unwrap(); + + let pages = get_page_stream(column, reader, vec![], Arc::new(|_, _| true), usize::MAX).await?; + + let mut statistics = get_field_columns(metadata.row_groups[row_group].columns(), field.name()) + .map(|column_meta| column_meta.statistics().transpose()) + .collect::>>()?; + + let pages = pages.collect::>().await; + + let iterator = BasicDecompressor::new(pages.into_iter(), vec![]); + + let mut arrays = collect(iterator, column.physical_type())?; + + Ok((arrays.pop().unwrap(), statistics.pop().unwrap())) +} + +fn get_column(path: &str, column: &str) -> Result<(Array, Option>)> { + let mut file = File::open(path).unwrap(); + read_column(&mut file, 0, column) +} + +fn test_column(column: &str) -> Result<()> { + let mut path = get_path(); + path.push("alltypes_plain.parquet"); + let path = path.to_str().unwrap(); + let (result, statistics) = get_column(path, column)?; + // the file does not have statistics + assert_eq!(statistics.as_ref().map(|x| x.as_ref()), None); + assert_eq!(result, alltypes_plain(column)); + Ok(()) +} + +#[test] +fn int32() -> Result<()> { + test_column("id") +} + +#[test] +fn bool() -> Result<()> { + test_column("bool_col") +} + +#[test] +fn tinyint_col() -> Result<()> { + test_column("tinyint_col") +} + +#[test] +fn smallint_col() -> Result<()> { + test_column("smallint_col") +} + +#[test] +fn int_col() -> Result<()> { + test_column("int_col") +} + +#[test] +fn bigint_col() -> Result<()> { + test_column("bigint_col") +} + +#[test] +fn float_col() -> Result<()> { + test_column("float_col") +} + +#[test] +fn double_col() -> Result<()> { + test_column("double_col") +} + +#[test] +fn timestamp_col() -> Result<()> { + let mut path = get_path(); + path.push("alltypes_plain.parquet"); + let path = path.to_str().unwrap(); + + let expected = vec![ + 1235865600000000000i64, + 1235865660000000000, + 1238544000000000000, + 1238544060000000000, + 1233446400000000000, + 1233446460000000000, + 1230768000000000000, + 1230768060000000000, + ]; + + let expected = expected.into_iter().map(Some).collect::>(); + let (array, _) = get_column(path, "timestamp_col")?; + if let Array::Int96(array) = array { + let a = array + .into_iter() + .map(|x| x.map(int96_to_i64_ns)) + .collect::>(); + assert_eq!(expected, a); + } else { + panic!("Timestamp expected"); + }; + Ok(()) +} + +#[test] +fn test_metadata() -> Result<()> { + let mut testdata = get_path(); + testdata.push("alltypes_plain.parquet"); + let mut file = File::open(testdata).unwrap(); + + let metadata = read_metadata(&mut file)?; + + let columns = metadata.schema_descr.columns(); + + /* + from pyarrow: + required group field_id=0 schema { + optional int32 field_id=1 id; + optional boolean field_id=2 bool_col; + optional int32 field_id=3 tinyint_col; + optional int32 field_id=4 smallint_col; + optional int32// pub enum Value { + // UInt32(Option), + // Int32(Option), + // Int64(Option), + // Int96(Option<[u32; 3]>), + // Float32(Option), + // Float64(Option), + // Boolean(Option), + // Binary(Option>), + // FixedLenBinary(Option>), + // List(Option), + // } + field_id=5 int_col; + optional int64 field_id=6 bigint_col; + optional float field_id=7 float_col; + optional double field_id=8 double_col; + optional binary field_id=9 date_string_col; + optional binary field_id=10 string_col; + optional int96 field_id=11 timestamp_col; + } + */ + let expected = vec![ + PhysicalType::Int32, + PhysicalType::Boolean, + PhysicalType::Int32, + PhysicalType::Int32, + PhysicalType::Int32, + PhysicalType::Int64, + PhysicalType::Float, + PhysicalType::Double, + PhysicalType::ByteArray, + PhysicalType::ByteArray, + PhysicalType::Int96, + ]; + + let result = columns + .iter() + .map(|column| { + assert_eq!( + column.descriptor.primitive_type.field_info.repetition, + Repetition::Optional + ); + column.descriptor.primitive_type.physical_type + }) + .collect::>(); + + assert_eq!(expected, result); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/read/primitive.rs b/crates/polars/tests/it/io/parquet/read/primitive.rs new file mode 100644 index 0000000000000..6d40d0dc92809 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/primitive.rs @@ -0,0 +1,114 @@ +use polars_parquet::parquet::deserialize::{ + native_cast, Casted, HybridRleDecoderIter, HybridRleIter, NativePageState, OptionalValues, + SliceFilteredIter, +}; +use polars_parquet::parquet::encoding::hybrid_rle::Decoder; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::page::{split_buffer, DataPage}; +use polars_parquet::parquet::schema::Repetition; +use polars_parquet::parquet::types::NativeType; + +use super::dictionary::PrimitivePageDict; +use super::utils::deserialize_optional; + +/// The deserialization state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +pub enum FilteredPageState<'a, T> +where + T: NativeType, +{ + /// A page of optional values + Optional(SliceFilteredIter, Casted<'a, T>>>), + /// A page of required values + Required(SliceFilteredIter>), +} + +/// The deserialization state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub enum PageState<'a, T> +where + T: NativeType, +{ + Nominal(NativePageState<'a, T, &'a PrimitivePageDict>), + Filtered(FilteredPageState<'a, T>), +} + +impl<'a, T: NativeType> PageState<'a, T> { + /// Tries to create [`NativePageState`] + /// # Error + /// Errors iff the page is not a `NativePageState` + pub fn try_new( + page: &'a DataPage, + dict: Option<&'a PrimitivePageDict>, + ) -> Result { + if let Some(selected_rows) = page.selected_rows() { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + + match (page.encoding(), dict, is_optional) { + (Encoding::Plain, _, true) => { + let (_, def_levels, _) = split_buffer(page)?; + + let validity = HybridRleDecoderIter::new(HybridRleIter::new( + Decoder::new(def_levels, 1), + page.num_values(), + )); + let values = native_cast(page)?; + + // validity and values interleaved. + let values = OptionalValues::new(validity, values); + + let values = + SliceFilteredIter::new(values, selected_rows.iter().copied().collect()); + + Ok(Self::Filtered(FilteredPageState::Optional(values))) + }, + (Encoding::Plain, _, false) => { + let values = SliceFilteredIter::new( + native_cast(page)?, + selected_rows.iter().copied().collect(), + ); + Ok(Self::Filtered(FilteredPageState::Required(values))) + }, + _ => Err(Error::FeatureNotSupported(format!( + "Viewing page for encoding {:?} for native type {}", + page.encoding(), + std::any::type_name::() + ))), + } + } else { + NativePageState::try_new(page, dict).map(Self::Nominal) + } + } +} + +pub fn page_to_vec( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result>, Error> { + assert_eq!(page.descriptor.max_rep_level, 0); + let state = PageState::::try_new(page, dict)?; + + match state { + PageState::Nominal(state) => match state { + NativePageState::Optional(validity, mut values) => { + deserialize_optional(validity, values.by_ref().map(Ok)) + }, + NativePageState::Required(values) => Ok(values.map(Some).collect()), + NativePageState::RequiredDictionary(dict) => dict + .indexes + .map(|x| dict.dict.value(x as usize).copied().map(Some)) + .collect(), + NativePageState::OptionalDictionary(validity, dict) => { + let values = dict.indexes.map(|x| dict.dict.value(x as usize).copied()); + deserialize_optional(validity, values) + }, + }, + PageState::Filtered(state) => match state { + FilteredPageState::Optional(values) => Ok(values.collect()), + FilteredPageState::Required(values) => Ok(values.map(Some).collect()), + }, + } +} diff --git a/crates/polars/tests/it/io/parquet/read/primitive_nested.rs b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs new file mode 100644 index 0000000000000..f8036b3fe17f4 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/primitive_nested.rs @@ -0,0 +1,224 @@ +use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; +use polars_parquet::parquet::encoding::{bitpacked, uleb128, Encoding}; +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::page::{split_buffer, DataPage}; +use polars_parquet::parquet::read::levels::get_bit_width; +use polars_parquet::parquet::types::NativeType; + +use super::dictionary::PrimitivePageDict; +use super::Array; + +fn read_buffer(values: &[u8]) -> impl Iterator + '_ { + let chunks = values.chunks_exact(std::mem::size_of::()); + chunks.map(|chunk| { + // unwrap is infalible due to the chunk size. + let chunk: T::Bytes = match chunk.try_into() { + Ok(v) => v, + Err(_) => panic!(), + }; + T::from_le_bytes(chunk) + }) +} + +// todo: generalize i64 -> T +fn compose_array, F: Iterator, G: Iterator>( + rep_levels: I, + def_levels: F, + max_rep: u32, + max_def: u32, + mut values: G, +) -> Result { + let mut outer = vec![]; + let mut inner = vec![]; + + assert_eq!(max_rep, 1); + assert_eq!(max_def, 3); + let mut prev_def = 0; + rep_levels + .into_iter() + .zip(def_levels.into_iter()) + .try_for_each(|(rep, def)| { + match rep { + 1 => {}, + 0 => { + if prev_def > 1 { + let old = std::mem::take(&mut inner); + outer.push(Some(Array::Int64(old))); + } + }, + _ => unreachable!(), + } + match def { + 3 => inner.push(Some(values.next().unwrap())), + 2 => inner.push(None), + 1 => outer.push(Some(Array::Int64(vec![]))), + 0 => outer.push(None), + _ => unreachable!(), + } + prev_def = def; + Ok::<(), Error>(()) + })?; + outer.push(Some(Array::Int64(inner))); + Ok(Array::List(outer)) +} + +fn read_array_impl>( + rep_levels: &[u8], + def_levels: &[u8], + values: I, + length: usize, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let max_rep_level = rep_level_encoding.1 as u32; + let max_def_level = def_level_encoding.1 as u32; + + match ( + (rep_level_encoding.0, max_rep_level == 0), + (def_level_encoding.0, max_def_level == 0), + ) { + ((Encoding::Rle, true), (Encoding::Rle, true)) => compose_array( + std::iter::repeat(0).take(length), + std::iter::repeat(0).take(length), + max_rep_level, + max_def_level, + values, + ), + ((Encoding::Rle, false), (Encoding::Rle, true)) => { + let num_bits = get_bit_width(rep_level_encoding.1); + let rep_levels = HybridRleDecoder::try_new(rep_levels, num_bits, length)?; + compose_array( + rep_levels, + std::iter::repeat(0).take(length), + max_rep_level, + max_def_level, + values, + ) + }, + ((Encoding::Rle, true), (Encoding::Rle, false)) => { + let num_bits = get_bit_width(def_level_encoding.1); + let def_levels = HybridRleDecoder::try_new(def_levels, num_bits, length)?; + compose_array( + std::iter::repeat(0).take(length), + def_levels, + max_rep_level, + max_def_level, + values, + ) + }, + ((Encoding::Rle, false), (Encoding::Rle, false)) => { + let rep_levels = + HybridRleDecoder::try_new(rep_levels, get_bit_width(rep_level_encoding.1), length)?; + let def_levels = + HybridRleDecoder::try_new(def_levels, get_bit_width(def_level_encoding.1), length)?; + compose_array(rep_levels, def_levels, max_rep_level, max_def_level, values) + }, + _ => todo!(), + } +} + +fn read_array( + rep_levels: &[u8], + def_levels: &[u8], + values: &[u8], + length: u32, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let values = read_buffer::(values); + read_array_impl::<_>( + rep_levels, + def_levels, + values, + length as usize, + rep_level_encoding, + def_level_encoding, + ) +} + +pub fn page_to_array( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result { + let (rep_levels, def_levels, values) = split_buffer(page)?; + + match (&page.encoding(), dict) { + (Encoding::Plain, None) => read_array( + rep_levels, + def_levels, + values, + page.num_values() as u32, + ( + &page.repetition_level_encoding(), + page.descriptor.max_rep_level, + ), + ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ), + ), + _ => todo!(), + } +} + +fn read_dict_array( + rep_levels: &[u8], + def_levels: &[u8], + values: &[u8], + length: u32, + dict: &PrimitivePageDict, + rep_level_encoding: (&Encoding, i16), + def_level_encoding: (&Encoding, i16), +) -> Result { + let dict_values = dict.values(); + + let bit_width = values[0]; + let values = &values[1..]; + + let (_, consumed) = uleb128::decode(values)?; + let values = &values[consumed..]; + + let indices = bitpacked::Decoder::::try_new(values, bit_width as usize, length as usize)?; + + let values = indices.map(|id| dict_values[id as usize]); + + read_array_impl::<_>( + rep_levels, + def_levels, + values, + length as usize, + rep_level_encoding, + def_level_encoding, + ) +} + +pub fn page_dict_to_array( + page: &DataPage, + dict: Option<&PrimitivePageDict>, +) -> Result { + assert_eq!(page.descriptor.max_rep_level, 1); + + let (rep_levels, def_levels, values) = split_buffer(page)?; + + match (page.encoding(), dict) { + (Encoding::PlainDictionary, Some(dict)) => read_dict_array( + rep_levels, + def_levels, + values, + page.num_values() as u32, + dict, + ( + &page.repetition_level_encoding(), + page.descriptor.max_rep_level, + ), + ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ), + ), + (_, None) => Err(Error::OutOfSpec( + "A dictionary-encoded page MUST be preceded by a dictionary page".to_string(), + )), + _ => todo!(), + } +} diff --git a/crates/polars/tests/it/io/parquet/read/struct_.rs b/crates/polars/tests/it/io/parquet/read/struct_.rs new file mode 100644 index 0000000000000..4e6af5b9eb970 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/struct_.rs @@ -0,0 +1,27 @@ +use polars_parquet::parquet::encoding::hybrid_rle::HybridRleDecoder; +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::page::{split_buffer, DataPage}; +use polars_parquet::parquet::read::levels::get_bit_width; + +pub fn extend_validity(val: &mut Vec, page: &DataPage) -> Result<(), Error> { + let (_, def_levels, _) = split_buffer(page)?; + let length = page.num_values(); + + if page.descriptor.max_def_level == 0 { + return Ok(()); + } + + let def_level_encoding = ( + &page.definition_level_encoding(), + page.descriptor.max_def_level, + ); + + let mut def_levels = + HybridRleDecoder::try_new(def_levels, get_bit_width(def_level_encoding.1), length)?; + + val.reserve(length); + def_levels.try_for_each(|x| { + val.push(x != 0); + Ok(()) + }) +} diff --git a/crates/polars/tests/it/io/parquet/read/utils.rs b/crates/polars/tests/it/io/parquet/read/utils.rs new file mode 100644 index 0000000000000..81492e60936e6 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/read/utils.rs @@ -0,0 +1,66 @@ +use polars_parquet::parquet::deserialize::{ + DefLevelsDecoder, HybridDecoderBitmapIter, HybridEncoded, +}; +use polars_parquet::parquet::encoding::hybrid_rle::{BitmapIter, HybridRleDecoder}; +use polars_parquet::parquet::error::Error; + +pub fn deserialize_optional>>( + validity: DefLevelsDecoder, + values: I, +) -> Result>, Error> { + match validity { + DefLevelsDecoder::Bitmap(bitmap) => deserialize_bitmap(bitmap, values), + DefLevelsDecoder::Levels(levels, max_level) => { + deserialize_levels(levels, max_level, values) + }, + } +} + +fn deserialize_bitmap>>( + mut validity: HybridDecoderBitmapIter, + mut values: I, +) -> Result>, Error> { + let mut deserialized = Vec::with_capacity(validity.len()); + + validity.try_for_each(|run| match run? { + HybridEncoded::Bitmap(bitmap, length) => { + BitmapIter::new(bitmap, 0, length).try_for_each(|x| { + if x { + deserialized.push(values.next().transpose()?); + } else { + deserialized.push(None); + } + Result::<_, Error>::Ok(()) + }) + }, + HybridEncoded::Repeated(is_set, length) => { + if is_set { + deserialized.reserve(length); + for x in values.by_ref().take(length) { + deserialized.push(Some(x?)) + } + } else { + deserialized.extend(std::iter::repeat(None).take(length)) + } + Ok(()) + }, + })?; + Ok(deserialized) +} + +fn deserialize_levels>>( + levels: HybridRleDecoder, + max: u32, + mut values: I, +) -> Result>, Error> { + levels + .into_iter() + .map(|x| { + if x == max { + values.next().transpose() + } else { + Ok(None) + } + }) + .collect() +} diff --git a/crates/polars-parquet/tests/it/roundtrip.rs b/crates/polars/tests/it/io/parquet/roundtrip.rs similarity index 100% rename from crates/polars-parquet/tests/it/roundtrip.rs rename to crates/polars/tests/it/io/parquet/roundtrip.rs diff --git a/crates/polars/tests/it/io/parquet/write/binary.rs b/crates/polars/tests/it/io/parquet/write/binary.rs new file mode 100644 index 0000000000000..add477530fecd --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/binary.rs @@ -0,0 +1,87 @@ +use polars_parquet::parquet::encoding::hybrid_rle::encode_bool; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::metadata::Descriptor; +use polars_parquet::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; +use polars_parquet::parquet::statistics::{serialize_statistics, BinaryStatistics, Statistics}; +use polars_parquet::parquet::types::ord_binary; +use polars_parquet::parquet::write::WriteOptions; + +fn unzip_option(array: &[Option>]) -> Result<(Vec, Vec)> { + // leave the first 4 bytes anouncing the length of the def level + // this will be overwritten at the end, once the length is known. + // This is unknown at this point because of the uleb128 encoding, + // whose length is variable. + let mut validity = std::io::Cursor::new(vec![0; 4]); + validity.set_position(4); + + let mut values = vec![]; + let iter = array.iter().map(|value| { + if let Some(item) = value { + values.extend_from_slice(&(item.len() as i32).to_le_bytes()); + values.extend_from_slice(item.as_ref()); + true + } else { + false + } + }); + encode_bool(&mut validity, iter)?; + + // write the length, now that it is known + let mut validity = validity.into_inner(); + let length = validity.len() - 4; + // todo: pay this small debt (loop?) + let length = length.to_le_bytes(); + validity[0] = length[0]; + validity[1] = length[1]; + validity[2] = length[2]; + validity[3] = length[3]; + + Ok((values, validity)) +} + +pub fn array_to_page_v1( + array: &[Option>], + options: &WriteOptions, + descriptor: &Descriptor, +) -> Result { + let (values, mut buffer) = unzip_option(array)?; + + buffer.extend_from_slice(&values); + + let statistics = if options.write_statistics { + let statistics = &BinaryStatistics { + primitive_type: descriptor.primitive_type.clone(), + null_count: Some((array.len() - array.iter().flatten().count()) as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .cloned(), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .cloned(), + } as &dyn Statistics; + Some(serialize_statistics(statistics)) + } else { + None + }; + + let header = DataPageHeaderV1 { + num_values: array.len() as i32, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }; + + Ok(Page::Data(DataPage::new( + DataPageHeader::V1(header), + buffer, + descriptor.clone(), + Some(array.len()), + ))) +} diff --git a/crates/polars/tests/it/io/parquet/write/indexes.rs b/crates/polars/tests/it/io/parquet/write/indexes.rs new file mode 100644 index 0000000000000..44c13b55ca501 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/indexes.rs @@ -0,0 +1,133 @@ +use std::io::Cursor; + +use polars_parquet::parquet::compression::CompressionOptions; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::indexes::{ + select_pages, BoundaryOrder, Index, Interval, NativeIndex, PageIndex, PageLocation, +}; +use polars_parquet::parquet::metadata::SchemaDescriptor; +use polars_parquet::parquet::read::{ + read_columns_indexes, read_metadata, read_pages_locations, BasicDecompressor, IndexedPageReader, +}; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType, PrimitiveType}; +use polars_parquet::parquet::write::{ + Compressor, DynIter, DynStreamingIterator, FileWriter, Version, WriteOptions, +}; + +use super::super::read::collect; +use super::primitive::array_to_page_v1; +use super::Array; + +fn write_file() -> Result> { + let page1 = vec![Some(0), Some(1), None, Some(3), Some(4), Some(5), Some(6)]; + let page2 = vec![Some(10), Some(11)]; + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical( + "col1".to_string(), + PhysicalType::Int32, + )], + ); + + let pages = vec![ + array_to_page_v1::(&page1, &options, &schema.columns()[0].descriptor), + array_to_page_v1::(&page2, &options, &schema.columns()[0].descriptor), + ]; + + let pages = DynStreamingIterator::new(Compressor::new( + DynIter::new(pages.into_iter()), + CompressionOptions::Uncompressed, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + Ok(writer.into_inner().into_inner()) +} + +#[test] +fn read_indexed_page() -> Result<()> { + let data = write_file()?; + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + let column = 0; + let columns = &metadata.row_groups[0].columns(); + + // selected the rows + let intervals = &[Interval::new(2, 2)]; + + let pages = read_pages_locations(&mut reader, columns)?; + + let pages = select_pages(intervals, &pages[column], metadata.row_groups[0].num_rows())?; + + let pages = IndexedPageReader::new(reader, &columns[column], pages, vec![], vec![]); + + let pages = BasicDecompressor::new(pages, vec![]); + + let arrays = collect(pages, columns[column].physical_type())?; + + // the second item and length 2 + assert_eq!(arrays, vec![Array::Int32(vec![None, Some(3)])]); + + Ok(()) +} + +#[test] +fn read_indexes_and_locations() -> Result<()> { + let data = write_file()?; + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + let columns = &metadata.row_groups[0].columns(); + + let expected_page_locations = vec![vec![ + PageLocation { + offset: 4, + compressed_page_size: 63, + first_row_index: 0, + }, + PageLocation { + offset: 67, + compressed_page_size: 47, + first_row_index: 7, + }, + ]]; + let expected_index = vec![Box::new(NativeIndex:: { + primitive_type: PrimitiveType::from_physical("col1".to_string(), PhysicalType::Int32), + indexes: vec![ + PageIndex { + min: Some(0), + max: Some(6), + null_count: Some(1), + }, + PageIndex { + min: Some(10), + max: Some(11), + null_count: Some(0), + }, + ], + boundary_order: BoundaryOrder::Unordered, + }) as Box]; + + let indexes = read_columns_indexes(&mut reader, columns)?; + assert_eq!(&indexes, &expected_index); + + let pages = read_pages_locations(&mut reader, columns)?; + assert_eq!(pages, expected_page_locations); + + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/write/mod.rs b/crates/polars/tests/it/io/parquet/write/mod.rs new file mode 100644 index 0000000000000..dbb90b7a87b7e --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/mod.rs @@ -0,0 +1,289 @@ +mod binary; +mod indexes; +mod primitive; +mod sidecar; + +use std::io::{Cursor, Read, Seek}; +use std::sync::Arc; + +use polars_parquet::parquet::compression::{BrotliLevel, CompressionOptions}; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::metadata::{Descriptor, SchemaDescriptor}; +use polars_parquet::parquet::page::Page; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType}; +use polars_parquet::parquet::statistics::Statistics; +#[cfg(feature = "async")] +use polars_parquet::parquet::write::FileStreamer; +use polars_parquet::parquet::write::{ + Compressor, DynIter, DynStreamingIterator, FileWriter, Version, WriteOptions, +}; +use polars_parquet::read::read_metadata; +use primitive::array_to_page_v1; + +use super::{alltypes_plain, alltypes_statistics, Array}; + +pub fn array_to_page( + array: &Array, + options: &WriteOptions, + descriptor: &Descriptor, +) -> Result { + // using plain encoding format + match array { + Array::Int32(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Int64(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Int96(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Float(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Double(array) => primitive::array_to_page_v1(array, options, descriptor), + Array::Binary(array) => binary::array_to_page_v1(array, options, descriptor), + _ => todo!(), + } +} + +fn read_column(reader: &mut R) -> Result<(Array, Option>)> { + let (a, statistics) = super::read::read_column(reader, 0, "col")?; + Ok((a, statistics)) +} + +#[cfg(feature = "async")] +#[allow(dead_code)] +async fn read_column_async< + R: futures::AsyncRead + futures::AsyncSeek + Send + std::marker::Unpin, +>( + reader: &mut R, +) -> Result<(Array, Option>)> { + let (a, statistics) = super::read::read_column_async(reader, 0, "col").await?; + Ok((a, statistics)) +} + +fn test_column(column: &str, compression: CompressionOptions) -> Result<()> { + let array = alltypes_plain(column); + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + // prepare schema + let type_ = match array { + Array::Int32(_) => PhysicalType::Int32, + Array::Int64(_) => PhysicalType::Int64, + Array::Int96(_) => PhysicalType::Int96, + Array::Float(_) => PhysicalType::Float, + Array::Double(_) => PhysicalType::Double, + Array::Binary(_) => PhysicalType::ByteArray, + _ => todo!(), + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical("col".to_string(), type_)], + ); + + let a = schema.columns(); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page( + &array, + &options, + &a[0].descriptor, + ))), + compression, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + + let (result, statistics) = read_column(&mut Cursor::new(data))?; + assert_eq!(array, result); + let stats = alltypes_statistics(column); + assert_eq!( + statistics.as_ref().map(|x| x.as_ref()), + Some(stats).as_ref().map(|x| x.as_ref()) + ); + Ok(()) +} + +#[test] +fn int32() -> Result<()> { + test_column("id", CompressionOptions::Uncompressed) +} + +#[test] +fn int32_snappy() -> Result<()> { + test_column("id", CompressionOptions::Snappy) +} + +#[test] +fn int32_lz4() -> Result<()> { + test_column("id", CompressionOptions::Lz4Raw) +} + +#[test] +fn int32_lz4_short_i32_array() -> Result<()> { + test_column("id-short-array", CompressionOptions::Lz4Raw) +} + +#[test] +fn int32_brotli() -> Result<()> { + test_column( + "id", + CompressionOptions::Brotli(Some(BrotliLevel::default())), + ) +} + +#[test] +#[ignore = "Native boolean writer not yet implemented"] +fn bool() -> Result<()> { + test_column("bool_col", CompressionOptions::Uncompressed) +} + +#[test] +fn tinyint() -> Result<()> { + test_column("tinyint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn smallint_col() -> Result<()> { + test_column("smallint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn int_col() -> Result<()> { + test_column("int_col", CompressionOptions::Uncompressed) +} + +#[test] +fn bigint_col() -> Result<()> { + test_column("bigint_col", CompressionOptions::Uncompressed) +} + +#[test] +fn float_col() -> Result<()> { + test_column("float_col", CompressionOptions::Uncompressed) +} + +#[test] +fn double_col() -> Result<()> { + test_column("double_col", CompressionOptions::Uncompressed) +} + +#[test] +fn basic() -> Result<()> { + let array = vec![ + Some(0), + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + Some(6), + ]; + + let options = WriteOptions { + write_statistics: false, + version: Version::V1, + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical( + "col".to_string(), + PhysicalType::Int32, + )], + ); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page_v1( + &array, + &options, + &schema.columns()[0].descriptor, + ))), + CompressionOptions::Uncompressed, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = Cursor::new(vec![]); + let mut writer = FileWriter::new(writer, schema, options, None); + + writer.write(DynIter::new(columns))?; + writer.end(None)?; + + let data = writer.into_inner().into_inner(); + let mut reader = Cursor::new(data); + + let metadata = read_metadata(&mut reader)?; + + // validated against an equivalent array produced by pyarrow. + let expected = 51; + assert_eq!( + metadata.row_groups[0].columns()[0].uncompressed_size(), + expected + ); + + Ok(()) +} + +#[cfg(feature = "async")] +#[allow(dead_code)] +async fn test_column_async(column: &str, compression: CompressionOptions) -> Result<()> { + let array = alltypes_plain(column); + + let options = WriteOptions { + write_statistics: true, + version: Version::V1, + }; + + // prepare schema + let type_ = match array { + Array::Int32(_) => PhysicalType::Int32, + Array::Int64(_) => PhysicalType::Int64, + Array::Int96(_) => PhysicalType::Int96, + Array::Float(_) => PhysicalType::Float, + Array::Double(_) => PhysicalType::Double, + Array::Binary(_) => PhysicalType::ByteArray, + _ => todo!(), + }; + + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical("col".to_string(), type_)], + ); + + let a = schema.columns(); + + let pages = DynStreamingIterator::new(Compressor::new_from_vec( + DynIter::new(std::iter::once(array_to_page( + &array, + &options, + &a[0].descriptor, + ))), + compression, + vec![], + )); + let columns = std::iter::once(Ok(pages)); + + let writer = futures::io::Cursor::new(vec![]); + let mut writer = FileStreamer::new(writer, schema, options, None); + + writer.write(DynIter::new(columns)).await?; + writer.end(None).await?; + + let data = writer.into_inner().into_inner(); + + let (result, statistics) = read_column_async(&mut futures::io::Cursor::new(data)).await?; + assert_eq!(array, result); + let stats = alltypes_statistics(column); + assert_eq!( + statistics.as_ref().map(|x| x.as_ref()), + Some(stats).as_ref().map(|x| x.as_ref()) + ); + Ok(()) +} diff --git a/crates/polars/tests/it/io/parquet/write/primitive.rs b/crates/polars/tests/it/io/parquet/write/primitive.rs new file mode 100644 index 0000000000000..9cab7f0977f95 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/primitive.rs @@ -0,0 +1,78 @@ +use polars_parquet::parquet::encoding::hybrid_rle::encode_bool; +use polars_parquet::parquet::encoding::Encoding; +use polars_parquet::parquet::error::Result; +use polars_parquet::parquet::metadata::Descriptor; +use polars_parquet::parquet::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; +use polars_parquet::parquet::statistics::{serialize_statistics, PrimitiveStatistics, Statistics}; +use polars_parquet::parquet::types::NativeType; +use polars_parquet::parquet::write::WriteOptions; + +fn unzip_option(array: &[Option]) -> Result<(Vec, Vec)> { + // leave the first 4 bytes anouncing the length of the def level + // this will be overwritten at the end, once the length is known. + // This is unknown at this point because of the uleb128 encoding, + // whose length is variable. + let mut validity = std::io::Cursor::new(vec![0; 4]); + validity.set_position(4); + + let mut values = vec![]; + let iter = array.iter().map(|value| { + if let Some(item) = value { + values.extend_from_slice(item.to_le_bytes().as_ref()); + true + } else { + false + } + }); + encode_bool(&mut validity, iter)?; + + // write the length, now that it is known + let mut validity = validity.into_inner(); + let length = validity.len() - 4; + // todo: pay this small debt (loop?) + let length = length.to_le_bytes(); + validity[0] = length[0]; + validity[1] = length[1]; + validity[2] = length[2]; + validity[3] = length[3]; + + Ok((values, validity)) +} + +pub fn array_to_page_v1( + array: &[Option], + options: &WriteOptions, + descriptor: &Descriptor, +) -> Result { + let (values, mut buffer) = unzip_option(array)?; + + buffer.extend_from_slice(&values); + + let statistics = if options.write_statistics { + let statistics = &PrimitiveStatistics { + primitive_type: descriptor.primitive_type.clone(), + null_count: Some((array.len() - array.iter().flatten().count()) as i64), + distinct_count: None, + max_value: array.iter().flatten().max_by(|x, y| x.ord(y)).copied(), + min_value: array.iter().flatten().min_by(|x, y| x.ord(y)).copied(), + } as &dyn Statistics; + Some(serialize_statistics(statistics)) + } else { + None + }; + + let header = DataPageHeaderV1 { + num_values: array.len() as i32, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }; + + Ok(Page::Data(DataPage::new( + DataPageHeader::V1(header), + buffer, + descriptor.clone(), + Some(array.len()), + ))) +} diff --git a/crates/polars/tests/it/io/parquet/write/sidecar.rs b/crates/polars/tests/it/io/parquet/write/sidecar.rs new file mode 100644 index 0000000000000..f1c654a23d983 --- /dev/null +++ b/crates/polars/tests/it/io/parquet/write/sidecar.rs @@ -0,0 +1,55 @@ +use polars_parquet::parquet::error::Error; +use polars_parquet::parquet::metadata::SchemaDescriptor; +use polars_parquet::parquet::schema::types::{ParquetType, PhysicalType}; +use polars_parquet::parquet::write::{write_metadata_sidecar, FileWriter, Version, WriteOptions}; + +#[test] +fn basic() -> Result<(), Error> { + let schema = SchemaDescriptor::new( + "schema".to_string(), + vec![ParquetType::from_physical( + "c1".to_string(), + PhysicalType::Int32, + )], + ); + + let mut metadatas = vec![]; + for i in 0..10 { + // say we will write 10 files + let relative_path = format!("part-{i}.parquet"); + let writer = std::io::Cursor::new(vec![]); + let mut writer = FileWriter::new( + writer, + schema.clone(), + WriteOptions { + write_statistics: true, + version: Version::V2, + }, + None, + ); + writer.end(None)?; + let (_, mut metadata) = writer.into_inner_and_metadata(); + + // once done, we write their relative paths: + metadata.row_groups.iter_mut().for_each(|row_group| { + row_group + .columns + .iter_mut() + .for_each(|column| column.file_path = Some(relative_path.clone())) + }); + metadatas.push(metadata); + } + + // merge their row groups + let first = metadatas.pop().unwrap(); + let sidecar = metadatas.into_iter().fold(first, |mut acc, metadata| { + acc.row_groups.extend(metadata.row_groups); + acc + }); + + // and write the metadata on a separate file + let mut writer = std::io::Cursor::new(vec![]); + write_metadata_sidecar(&mut writer, &sidecar)?; + + Ok(()) +} diff --git a/crates/polars/tests/it/lazy/group_by.rs b/crates/polars/tests/it/lazy/group_by.rs index 4e24e1d24fa81..d8e20c804ca0e 100644 --- a/crates/polars/tests/it/lazy/group_by.rs +++ b/crates/polars/tests/it/lazy/group_by.rs @@ -126,18 +126,19 @@ fn test_group_by_agg_list_with_not_aggregated() -> PolarsResult<()> { } #[test] -#[cfg(all(feature = "dtype-duration", feature = "dtype-struct"))] +#[cfg(all(feature = "dtype-duration", feature = "dtype-decimal"))] fn test_logical_mean_partitioned_group_by_block() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock(); let df = df![ - "a" => [1, 1, 2], + "decimal" => [1, 1, 2], "duration" => [1000, 2000, 3000] ]?; let out = df .lazy() + .with_column(col("decimal").cast(DataType::Decimal(None, Some(2)))) .with_column(col("duration").cast(DataType::Duration(TimeUnit::Microseconds))) - .group_by([col("a")]) + .group_by([col("decimal")]) .agg([col("duration").mean()]) .sort("duration", Default::default()) .collect()?; diff --git a/crates/polars/tests/it/main.rs b/crates/polars/tests/it/main.rs index 8cf14da210c37..4395ce47028f7 100644 --- a/crates/polars/tests/it/main.rs +++ b/crates/polars/tests/it/main.rs @@ -6,4 +6,7 @@ mod lazy; mod schema; mod time; +mod arrow; +mod chunks; + pub static FOODS_CSV: &str = "../../examples/datasets/foods1.csv"; diff --git a/crates/polars/tests/it/time/date_range.rs b/crates/polars/tests/it/time/date_range.rs index 9de815368d77f..ff8df835cce29 100644 --- a/crates/polars/tests/it/time/date_range.rs +++ b/crates/polars/tests/it/time/date_range.rs @@ -1,6 +1,7 @@ use polars::export::chrono::NaiveDate; use polars::prelude::*; -use polars::time::{date_range, ClosedWindow, Duration}; +#[allow(unused_imports)] +use polars::time::date_range; #[test] fn test_time_units_9413() { diff --git a/docs/_build/snippets/under_construction.md b/docs/_build/snippets/under_construction.md index 9ac0d3fcc6139..00b0cc4af9225 100644 --- a/docs/_build/snippets/under_construction.md +++ b/docs/_build/snippets/under_construction.md @@ -1,4 +1,4 @@ !!! warning ":construction: Under Construction :construction:" This section is still under development. Want to help out? Consider contributing and making a [pull request](https://github.com/pola-rs/polars) to our repository. - Please read our [Contribution Guidelines](https://github.com/pola-rs/polars/blob/main/CONTRIBUTING.md) on how to proceed. + Please read our [contributing guide](https://docs.pola.rs/development/contributing/) on how to proceed. diff --git a/docs/data/alltypes_plain.parquet b/docs/data/alltypes_plain.parquet new file mode 100644 index 0000000000000..a63f5dca7c382 Binary files /dev/null and b/docs/data/alltypes_plain.parquet differ diff --git a/docs/development/contributing/index.md b/docs/development/contributing/index.md index eb1c5c7a3572a..809a149e5160a 100644 --- a/docs/development/contributing/index.md +++ b/docs/development/contributing/index.md @@ -128,7 +128,7 @@ When you have resolved your issue, [open a pull request](https://docs.github.com Please adhere to the following guidelines: - Start your pull request title with a [conventional commit](https://www.conventionalcommits.org/) tag. This helps us add your contribution to the right section of the changelog. We use the [Angular convention](https://github.com/angular/angular/blob/22b96b9/CONTRIBUTING.md#type). Scope can be `rust` and/or `python`, depending on your contribution. -- Use a descriptive title. This text will end up in the [changelog](https://github.com/pola-rs/polars/releases). +- Use a descriptive title starting with an uppercase letter. This text will end up in the [changelog](https://github.com/pola-rs/polars/releases). - In the pull request description, [link](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) to the issue you were working on. - Add any relevant information to the description that you think may help the maintainers review your code. - Make sure your branch is [rebased](https://docs.github.com/en/get-started/using-git/about-git-rebase) against the latest version of the `main` branch. diff --git a/docs/development/versioning.md b/docs/development/versioning.md index 2d8009e8dbe5f..727048b439d2f 100644 --- a/docs/development/versioning.md +++ b/docs/development/versioning.md @@ -31,7 +31,7 @@ We know it takes time and energy for our users to keep up with new releases but, **A breaking change occurs when an existing component of the public API is changed or removed.** -A feature is part of the public API if it is documented in the [API reference](https://docs.pola.rs/py-polars/html/reference/). +A feature is part of the public API if it is documented in the [API reference](https://docs.pola.rs/py-polars/html/reference/index.html). Examples of breaking changes: diff --git a/docs/mlc-config.json b/docs/mlc-config.json index e77aed6c4d0ed..e9a807932ac38 100644 --- a/docs/mlc-config.json +++ b/docs/mlc-config.json @@ -2,6 +2,8 @@ "ignorePatterns": [ { "pattern": "^https://crates.io/" + },{ + "pattern": "^https://stackoverflow.com/" } ] } diff --git a/docs/src/python/user-guide/misc/multiprocess.py b/docs/src/python/user-guide/misc/multiprocess.py index 55aec52d6b9f0..6876b05537526 100644 --- a/docs/src/python/user-guide/misc/multiprocess.py +++ b/docs/src/python/user-guide/misc/multiprocess.py @@ -51,6 +51,7 @@ def main(): # --8<-- [end:example1] """ + # --8<-- [start:example2] import multiprocessing import polars as pl diff --git a/docs/src/rust/user-guide/expressions/casting.rs b/docs/src/rust/user-guide/expressions/casting.rs index 3729ca0492ca7..b18ca19022dfd 100644 --- a/docs/src/rust/user-guide/expressions/casting.rs +++ b/docs/src/rust/user-guide/expressions/casting.rs @@ -1,5 +1,4 @@ // --8<-- [start:setup] -use polars::lazy::dsl::StrptimeOptions; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/src/rust/user-guide/expressions/structs.rs b/docs/src/rust/user-guide/expressions/structs.rs index 502f423fdf0dd..01c08eaf3d7f8 100644 --- a/docs/src/rust/user-guide/expressions/structs.rs +++ b/docs/src/rust/user-guide/expressions/structs.rs @@ -1,5 +1,4 @@ // --8<-- [start:setup] -use polars::lazy::dsl::len; use polars::prelude::*; // --8<-- [end:setup] fn main() -> Result<(), Box> { diff --git a/docs/src/rust/user-guide/getting-started/reading-writing.rs b/docs/src/rust/user-guide/getting-started/reading-writing.rs index dad5e8713d248..bc021e9a21de9 100644 --- a/docs/src/rust/user-guide/getting-started/reading-writing.rs +++ b/docs/src/rust/user-guide/getting-started/reading-writing.rs @@ -13,7 +13,8 @@ fn main() -> Result<(), Box> { NaiveDate::from_ymd_opt(2025, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2025, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), ], - "float" => &[4.0, 5.0, 6.0] + "float" => &[4.0, 5.0, 6.0], + "string" => &["a", "b", "c"], ) .unwrap(); println!("{}", df); diff --git a/docs/src/rust/user-guide/transformations/pivot.rs b/docs/src/rust/user-guide/transformations/pivot.rs index 2115b528579cf..804ead13f056c 100644 --- a/docs/src/rust/user-guide/transformations/pivot.rs +++ b/docs/src/rust/user-guide/transformations/pivot.rs @@ -7,20 +7,29 @@ fn main() -> Result<(), Box> { // --8<-- [start:df] let df = df!( "foo"=> ["A", "A", "B", "B", "C"], - "N"=> [1, 2, 2, 4, 2], "bar"=> ["k", "l", "m", "n", "o"], + "N"=> [1, 2, 2, 4, 2], )?; println!("{}", &df); // --8<-- [end:df] // --8<-- [start:eager] - let out = pivot(&df, ["N"], ["foo"], ["bar"], false, None, None)?; + let out = pivot(&df, ["foo"], ["bar"], Some(["N"]), false, None, None)?; println!("{}", &out); // --8<-- [end:eager] // --8<-- [start:lazy] let q = df.lazy(); - let q2 = pivot(&q.collect()?, ["N"], ["foo"], ["bar"], false, None, None)?.lazy(); + let q2 = pivot( + &q.collect()?, + ["foo"], + ["bar"], + Some(["N"]), + false, + None, + None, + )? + .lazy(); let out = q2.collect()?; println!("{}", &out); // --8<-- [end:lazy] diff --git a/docs/src/rust/user-guide/transformations/time-series/filter.rs b/docs/src/rust/user-guide/transformations/time-series/filter.rs index 56c6589c1555b..06ce39eb0c5f0 100644 --- a/docs/src/rust/user-guide/transformations/time-series/filter.rs +++ b/docs/src/rust/user-guide/transformations/time-series/filter.rs @@ -1,7 +1,6 @@ // --8<-- [start:setup] use chrono::prelude::*; use polars::io::prelude::*; -use polars::lazy::dsl::StrptimeOptions; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/src/rust/user-guide/transformations/time-series/parsing.rs b/docs/src/rust/user-guide/transformations/time-series/parsing.rs index b35a522157419..3462943d15afb 100644 --- a/docs/src/rust/user-guide/transformations/time-series/parsing.rs +++ b/docs/src/rust/user-guide/transformations/time-series/parsing.rs @@ -1,6 +1,5 @@ // --8<-- [start:setup] use polars::io::prelude::*; -use polars::lazy::dsl::StrptimeOptions; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/src/rust/user-guide/transformations/time-series/rolling.rs index fc81f34412bbd..c9b7e58906ccb 100644 --- a/docs/src/rust/user-guide/transformations/time-series/rolling.rs +++ b/docs/src/rust/user-guide/transformations/time-series/rolling.rs @@ -1,7 +1,6 @@ // --8<-- [start:setup] use chrono::prelude::*; use polars::io::prelude::*; -use polars::lazy::dsl::GetOutput; use polars::prelude::*; // --8<-- [end:setup] diff --git a/docs/user-guide/expressions/plugins.md b/docs/user-guide/expressions/plugins.md index 1384eca05e294..60c5aedfb7afb 100644 --- a/docs/user-guide/expressions/plugins.md +++ b/docs/user-guide/expressions/plugins.md @@ -37,7 +37,7 @@ crate-type = ["cdylib"] [dependencies] polars = { version = "*" } -pyo3 = { version = "*", features = ["extension-module"] } +pyo3 = { version = "*", features = ["extension-module", "abi-py38"] } pyo3-polars = { version = "*", features = ["derive"] } serde = { version = "*", features = ["derive"] } ``` @@ -92,24 +92,22 @@ expression in batches. Whereas for other operations this would not be allowed, t ```python # expression_lib/__init__.py +from pathlib import Path +from typing import TYPE_CHECKING + import polars as pl +from polars.plugins import register_plugin_function from polars.type_aliases import IntoExpr -from polars.utils.udfs import _get_shared_lib_location - -# Boilerplate needed to inform Polars of the location of binary wheel. -lib = _get_shared_lib_location(__file__) - -@pl.api.register_expr_namespace("language") -class Language: - def __init__(self, expr: pl.Expr): - self._expr = expr - - def pig_latinnify(self) -> pl.Expr: - return self._expr._register_plugin( - lib=lib, - symbol="pig_latinnify", - is_elementwise=True, - ) + + +def pig_latinnify(expr: IntoExpr) -> pl.Expr: + """Pig-latinnify expression.""" + return register_plugin_function( + plugin_path=Path(__file__).parent, + function_name="pig_latinnify", + args=expr, + is_elementwise=True, + ) ``` We can then compile this library in our environment by installing `maturin` and running `maturin develop --release`. @@ -118,15 +116,19 @@ And that's it. Our expression is ready to use! ```python import polars as pl -from expression_lib import Language +from expression_lib import pig_latinnify df = pl.DataFrame( { "convert": ["pig", "latin", "is", "silly"], } ) +out = df.with_columns(pig_latin=pig_latinnify("convert")) +``` +Alternatively, you can [register a custom namespace](https://docs.pola.rs/py-polars/html/reference/api/polars.api.register_expr_namespace.html#polars.api.register_expr_namespace), which enables you to write: +```python out = df.with_columns( pig_latin=pl.col("convert").language.pig_latinnify(), ) @@ -173,33 +175,28 @@ fn append_kwargs(input: &[Series], kwargs: MyKwargs) -> PolarsResult { On the Python side the kwargs can be passed when we register the plugin. ```python -@pl.api.register_expr_namespace("my_expr") -class MyCustomExpr: - def __init__(self, expr: pl.Expr): - self._expr = expr - - def append_args( - self, - float_arg: float, - integer_arg: int, - string_arg: str, - boolean_arg: bool, - ) -> pl.Expr: - """ - This example shows how arguments other than `Series` can be used. - """ - return self._expr._register_plugin( - lib=lib, - args=[], - kwargs={ - "float_arg": float_arg, - "integer_arg": integer_arg, - "string_arg": string_arg, - "boolean_arg": boolean_arg, - }, - symbol="append_kwargs", - is_elementwise=True, - ) +def append_args( + expr: IntoExpr, + float_arg: float, + integer_arg: int, + string_arg: str, + boolean_arg: bool, +) -> pl.Expr: + """ + This example shows how arguments other than `Series` can be used. + """ + return register_plugin_function( + plugin_path=Path(__file__).parent, + function_name="append_kwargs", + args=expr, + kwargs={ + "float_arg": float_arg, + "integer_arg": integer_arg, + "string_arg": string_arg, + "boolean_arg": boolean_arg, + }, + is_elementwise=True, + ) ``` ## Output data types @@ -242,14 +239,24 @@ fn haversine(inputs: &[Series]) -> PolarsResult { } ``` -That's all you need to know to get started. Take a look at this [repo](https://github.com/pola-rs/pyo3-polars/tree/main/example/derive_expression) to see how this all fits together. +That's all you need to know to get started. Take a look at [this repo](https://github.com/pola-rs/pyo3-polars/tree/main/example/derive_expression) to see how this all fits together, and at [this tutorial](https://marcogorelli.github.io/polars-plugins-tutorial/) +to gain a more thorough understanding. ## Community plugins -Here is a curated (non-exhaustive) list of community implemented plugins. +Here is a curated (non-exhaustive) list of community-implemented plugins. - [polars-xdt](https://github.com/pola-rs/polars-xdt) Polars plugin with extra datetime-related functionality which isn't quite in-scope for the main library - [polars-distance](https://github.com/ion-elgreco/polars-distance) Polars plugin for pairwise distance functions - [polars-ds](https://github.com/abstractqqq/polars_ds_extension) Polars extension aiming to simplify common numerical/string data analysis procedures - [polars-hash](https://github.com/ion-elgreco/polars-hash) Stable non-cryptographic and cryptographic hashing functions for Polars +- [polars-reverse-geocode](https://github.com/MarcoGorelli/polars-reverse-geocode) Offline reverse geocoder for finding the closest city + to a given (latitude, longitude) pair + +## Other material + +- [Ritchie Vink - Keynote on Polars Plugins](https://youtu.be/jKW-CBV7NUM) +- [Polars plugins tutorial](https://marcogorelli.github.io/polars-plugins-tutorial/) Learn how to write a plugin by + going through some very simple and minimal examples +- [cookiecutter-polars-plugin](https://github.com/MarcoGorelli/cookiecutter-polars-plugins) Project template for Polars Plugins diff --git a/docs/user-guide/getting-started.md b/docs/user-guide/getting-started.md index 4a841961986dd..2a601597bb3df 100644 --- a/docs/user-guide/getting-started.md +++ b/docs/user-guide/getting-started.md @@ -126,7 +126,7 @@ We will create a new `DataFrame` for the Group by functionality. This new `DataF print(df2) ``` -{{code_block('user-guide/bgetting-startedasics/expressions','group_by',['group_by'])}} +{{code_block('user-guide/getting-started/expressions','group_by',['group_by'])}} ```python exec="on" result="text" session="getting-started/expressions" print( diff --git a/docs/user-guide/installation.md b/docs/user-guide/installation.md index 30eeb68b45751..83aa684bf92e3 100644 --- a/docs/user-guide/installation.md +++ b/docs/user-guide/installation.md @@ -123,7 +123,7 @@ The opt-in features are: - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`. And activates `pivot` and `transpose` operations - `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match. - - `cross_join` - Create the cartesian product of two DataFrames. + - `cross_join` - Create the Cartesian product of two DataFrames. - `semi_anti_join` - SEMI and ANTI joins. - `group_by_list` - Allow group by operation on keys of type List. - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked diff --git a/docs/user-guide/transformations/joins.md b/docs/user-guide/transformations/joins.md index 07dee43127bff..1a6f293371911 100644 --- a/docs/user-guide/transformations/joins.md +++ b/docs/user-guide/transformations/joins.md @@ -2,7 +2,7 @@ ## Join strategies -Polars supports the following join strategies by specifying the `strategy` argument: +Polars supports the following join strategies by specifying the `how` argument: | Strategy | Description | | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | @@ -10,7 +10,6 @@ Polars supports the following join strategies by specifying the `strategy` argum | `left` | Returns all rows in the left dataframe, whether or not a match in the right-frame is found. Non-matching rows have their right columns null-filled. | | `outer` | Returns all rows from both the left and right dataframe. If no match is found in one frame, columns from the other frame are null-filled. | | `cross` | Returns the Cartesian product of all rows from the left frame with all rows from the right frame. Duplicates rows are retained; the table length of `A` cross-joined with `B` is always `len(A) × len(B)`. | -| `asof` | A left-join in which the match is performed on the _nearest_ key rather than on equal keys. | | `semi` | Returns all rows from the left frame in which the join key is also present in the right frame. | | `anti` | Returns all rows from the left frame in which the join key is _not_ present in the right frame. | @@ -65,7 +64,7 @@ The `outer` join produces a `DataFrame` that contains all the rows from both `Da ### Cross join -A `cross` join is a cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. +A `cross` join is a Cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. {{code_block('user-guide/transformations/joins','df3',['DataFrame'])}} @@ -139,10 +138,10 @@ Continuing this example, an alternative question might be: which of the cars hav --8<-- "python/user-guide/transformations/joins.py:anti" ``` -### Asof join +## Asof join An `asof` join is like a left join except that we match on nearest key rather than equal keys. -In Polars we can do an asof join with the `join` method and specifying `strategy="asof"`. However, for more flexibility we can use the `join_asof` method. +In Polars we can do an asof join with the `join_asof` method. Consider the following scenario: a stock market broker has a `DataFrame` called `df_trades` showing transactions it has made for different stocks. diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 3597cd9ed34db..ab8f2fd5c0076 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.20.8" +version = "0.20.16-rc.1" edition = "2021" [lib] diff --git a/py-polars/Makefile b/py-polars/Makefile index e835c041bbe24..420c7f75abf91 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -54,8 +54,13 @@ build-opt-native: .venv ## Same as build-opt, except with native CPU optimizati build-release-native: .venv ## Same as build-release, except with native CPU optimizations turned on @$(MAKE) -s -C .. $@ +.PHONY: lint +lint: .venv ## Run lint checks + $(VENV_BIN)/ruff check + -$(VENV_BIN)/mypy + .PHONY: fmt -fmt: .venv ## Run autoformatting and linting +fmt: .venv ## Run autoformatting (and lint) $(VENV_BIN)/ruff check $(VENV_BIN)/ruff format $(VENV_BIN)/typos @@ -68,7 +73,7 @@ clippy: ## Run clippy cargo clippy --locked -- -D warnings -D clippy::dbg_macro .PHONY: pre-commit -pre-commit: fmt clippy ## Run all code quality checks +pre-commit: fmt clippy ## Run all code formatting and lint/quality checks .PHONY: test test: .venv build ## Run fast unittests diff --git a/py-polars/debug/launch.py b/py-polars/debug/launch.py index 95352e4eafa30..2a469c7e9485c 100644 --- a/py-polars/debug/launch.py +++ b/py-polars/debug/launch.py @@ -47,8 +47,9 @@ def launch_debugging() -> None: if not found: msg = ( "Cannot locate pid definition in launch.json for Rust LLDB configuration. " - "Please follow the instructions in CONTRIBUTING.md for creating the " - "launch configuration." + "Please follow the instructions in the debugging section of the " + "contributing guide (https://docs.pola.rs/development/contributing/ide/#debugging) " + "for creating the launch configuration." ) raise RuntimeError(msg) diff --git a/py-polars/docs/requirements-docs.txt b/py-polars/docs/requirements-docs.txt index f1f88d7e29407..7da7161639ff0 100644 --- a/py-polars/docs/requirements-docs.txt +++ b/py-polars/docs/requirements-docs.txt @@ -1,5 +1,3 @@ ---prefer-binary - numpy pandas pyarrow @@ -16,6 +14,7 @@ sphinx-autosummary-accessors==2023.4.0 sphinx-copybutton==0.5.2 sphinx-design==0.5.0 sphinx-favicon==1.0.1 +sphinx_reredirects==0.1.3 sphinx-toolbox==3.5.0 livereload==2.6.3 diff --git a/py-polars/docs/source/_static/css/custom.css b/py-polars/docs/source/_static/css/custom.css index 9cdd3b3591d84..7732a4a2a2d27 100644 --- a/py-polars/docs/source/_static/css/custom.css +++ b/py-polars/docs/source/_static/css/custom.css @@ -24,26 +24,36 @@ html[data-theme="dark"] { --pst-color-border: #444444; } +/* add subtle gradients to sidebar and card elements */ div.bd-sidebar-primary { background-image: linear-gradient(90deg, var(--pst-gradient-sidebar-left) 0%, var(--pst-gradient-sidebar-right) 100%); } +div.sd-card { + background-image: linear-gradient(0deg, var(--pst-gradient-sidebar-left) 0%, var(--pst-gradient-sidebar-right) 100%); +} +/* match docs footer colour to the header */ footer.bd-footer { background-color: var(--pst-color-on-background); } /* - We're not currently doing anything meaningful with the right - ToC, so hide until there's actually something to put there... + we're not currently doing anything meaningful with the + right toc, so hide until there's something to put there */ div.bd-sidebar-secondary { display: none; } - label.sidebar-toggle.secondary-toggle { display: none !important; } +/* fix visited link colour */ a:visited { color: var(--pst-color-link); } + +/* fix ugly navbar scrollbar display */ +.sidebar-primary-items__end { + margin: 0 !important; +} diff --git a/py-polars/docs/source/conf.py b/py-polars/docs/source/conf.py index fe70f5c13641e..99ecc82295748 100644 --- a/py-polars/docs/source/conf.py +++ b/py-polars/docs/source/conf.py @@ -21,12 +21,14 @@ # Add py-polars directory sys.path.insert(0, str(Path("../..").resolve())) + # -- Project information ----------------------------------------------------- project = "Polars" author = "Ritchie Vink" copyright = f"2020, {author}" + # -- General configuration --------------------------------------------------- extensions = [ @@ -44,6 +46,7 @@ "sphinx_copybutton", "sphinx_design", "sphinx_favicon", + "sphinx_reredirects", "sphinx_toolbox.more_autodoc.overloads", ] @@ -67,6 +70,7 @@ # https://sphinx-toolbox.readthedocs.io/en/latest/ overloads_location = ["bottom"] + # -- Extension settings ----------------------------------------------------- # sphinx.ext.intersphinx - link to other projects' documentation @@ -89,6 +93,10 @@ copybutton_prompt_text = r">>> |\.\.\. " copybutton_prompt_is_regexp = True +# redirect empty root to the actual landing page +redirects = {"index": "reference/index.html"} + + # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. diff --git a/py-polars/docs/source/reference/api.rst b/py-polars/docs/source/reference/api.rst index 54e8ed02b4b56..8cd9dc77475b8 100644 --- a/py-polars/docs/source/reference/api.rst +++ b/py-polars/docs/source/reference/api.rst @@ -7,7 +7,7 @@ Providing new functionality --------------------------- These functions allow you to register custom functionality in a dedicated -namespace on the underlying polars classes without requiring subclassing +namespace on the underlying Polars classes without requiring subclassing or mixins. Expr, DataFrame, LazyFrame, and Series are all supported targets. This feature is primarily intended for use by library authors providing @@ -29,7 +29,7 @@ Available registrations .. note:: - You cannot override existing polars namespaces (such as ``.str`` or ``.dt``), and attempting to do so + You cannot override existing Polars namespaces (such as ``.str`` or ``.dt``), and attempting to do so will raise an `AttributeError `_. However, you *can* override other custom namespaces (which will only generate a `UserWarning `_). diff --git a/py-polars/docs/source/reference/config.rst b/py-polars/docs/source/reference/config.rst index 452ecd98c25c7..a8137cdff7a07 100644 --- a/py-polars/docs/source/reference/config.rst +++ b/py-polars/docs/source/reference/config.rst @@ -34,8 +34,8 @@ Config options Config.set_trim_decimal_zeros Config.set_verbose -Config load, save, and current state ------------------------------------- +Config load, save, state +------------------------ .. autosummary:: :toctree: api/ @@ -81,8 +81,8 @@ explicitly calling one or more of the available "set\_" methods on it... with pl.Config(verbose=True): do_various_things() -Use as a function decorator ---------------------------- +Use as a decorator +------------------ In the same vein, you can also use ``Config`` as a function decorator to temporarily set options for the duration of the function call: diff --git a/py-polars/docs/source/reference/datatypes.rst b/py-polars/docs/source/reference/datatypes.rst index 3e538998b0028..695923e928859 100644 --- a/py-polars/docs/source/reference/datatypes.rst +++ b/py-polars/docs/source/reference/datatypes.rst @@ -53,8 +53,8 @@ Other :toctree: api/ :nosignatures: - Boolean Binary + Boolean Categorical Enum Null diff --git a/py-polars/docs/source/reference/expressions/functions.rst b/py-polars/docs/source/reference/expressions/functions.rst index 3fad1cb7f989d..d240454e136b5 100644 --- a/py-polars/docs/source/reference/expressions/functions.rst +++ b/py-polars/docs/source/reference/expressions/functions.rst @@ -2,7 +2,7 @@ Functions ========= -These functions are available from the polars module root and can be used as expressions, and sometimes also in eager contexts. +These functions are available from the Polars module root and can be used as expressions, and sometimes also in eager contexts. ---- diff --git a/py-polars/docs/source/reference/expressions/index.rst b/py-polars/docs/source/reference/expressions/index.rst index 6c87796c30819..5f9b8b541dadb 100644 --- a/py-polars/docs/source/reference/expressions/index.rst +++ b/py-polars/docs/source/reference/expressions/index.rst @@ -2,7 +2,7 @@ Expressions =========== -This page gives an overview of all public polars expressions. +This page gives an overview of all public Polars expressions. .. toctree:: :maxdepth: 2 diff --git a/py-polars/docs/source/reference/expressions/meta.rst b/py-polars/docs/source/reference/expressions/meta.rst index c2bfed4728cf0..22868fad271db 100644 --- a/py-polars/docs/source/reference/expressions/meta.rst +++ b/py-polars/docs/source/reference/expressions/meta.rst @@ -17,5 +17,6 @@ The following methods are available under the `expr.meta` attribute. Expr.meta.pop Expr.meta.tree_format Expr.meta.root_names + Expr.meta.serialize Expr.meta.undo_aliases Expr.meta.write_json diff --git a/py-polars/docs/source/reference/expressions/miscellaneous.rst b/py-polars/docs/source/reference/expressions/miscellaneous.rst index a1997ccb9031f..c0ea4d2caf1bc 100644 --- a/py-polars/docs/source/reference/expressions/miscellaneous.rst +++ b/py-polars/docs/source/reference/expressions/miscellaneous.rst @@ -6,5 +6,6 @@ Miscellaneous .. autosummary:: :toctree: api/ - Expr.from_json - Expr.set_sorted + Expr.deserialize + Expr.from_json + Expr.set_sorted diff --git a/py-polars/docs/source/reference/index.rst b/py-polars/docs/source/reference/index.rst index d99d14bb5565b..70b48dc5399c7 100644 --- a/py-polars/docs/source/reference/index.rst +++ b/py-polars/docs/source/reference/index.rst @@ -1,24 +1,109 @@ -============= -API reference -============= - -This page gives an overview of all public polars objects, functions and -methods. All classes and functions exposed in ``polars.*`` namespace are public. - -.. toctree:: - :maxdepth: 2 - - io - series/index - dataframe/index - lazyframe/index - expressions/index - selectors - api - functions - datatypes - config - exceptions - testing - sql - metadata +==================== +Python API reference +==================== + +This page gives a high-level overview of all public Polars objects, functions and +methods. All classes and functions exposed in the ``polars.*`` namespace are public. + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + dataframe/index + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + lazyframe/index + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + series/index + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + expressions/index + selectors + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + functions + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + datatypes + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + io + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + config + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + api + plugins + + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 2 + + sql + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 1 + + exceptions + + .. toctree:: + :maxdepth: 2 + + testing + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 1 + + metadata diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index efc9e96603a86..d3c45469f94a7 100644 --- a/py-polars/docs/source/reference/io.rst +++ b/py-polars/docs/source/reference/io.rst @@ -3,6 +3,14 @@ Input/output ============ .. currentmodule:: polars +Avro +~~~~ +.. autosummary:: + :toctree: api/ + + read_avro + DataFrame.write_avro + CSV ~~~ .. autosummary:: @@ -14,29 +22,14 @@ CSV DataFrame.write_csv LazyFrame.sink_csv -Feather/ IPC -~~~~~~~~~~~~ -.. autosummary:: - :toctree: api/ - - read_ipc - read_ipc_stream - scan_ipc - read_ipc_schema - DataFrame.write_ipc - DataFrame.write_ipc_stream - LazyFrame.sink_ipc +.. currentmodule:: polars.io.csv.batched_reader -Parquet -~~~~~~~ .. autosummary:: :toctree: api/ - read_parquet - scan_parquet - read_parquet_schema - DataFrame.write_parquet - LazyFrame.sink_parquet + BatchedCsvReader.next_batches + +.. currentmodule:: polars Database ~~~~~~~~ @@ -47,27 +40,16 @@ Database read_database_uri DataFrame.write_database -JSON -~~~~ -.. autosummary:: - :toctree: api/ - - read_json - read_ndjson - scan_ndjson - DataFrame.write_json - DataFrame.write_ndjson - LazyFrame.sink_ndjson - -AVRO -~~~~ +Delta Lake +~~~~~~~~~~ .. autosummary:: :toctree: api/ - read_avro - DataFrame.write_avro + read_delta + scan_delta + DataFrame.write_delta -Spreadsheet +Excel / ODS ~~~~~~~~~~~ .. autosummary:: :toctree: api/ @@ -76,39 +58,54 @@ Spreadsheet read_ods DataFrame.write_excel -Apache Iceberg -~~~~~~~~~~~~~~ +Feather / IPC +~~~~~~~~~~~~~ .. autosummary:: :toctree: api/ - scan_iceberg + read_ipc + read_ipc_schema + read_ipc_stream + scan_ipc + DataFrame.write_ipc + DataFrame.write_ipc_stream + LazyFrame.sink_ipc -Delta Lake -~~~~~~~~~~ +Iceberg +~~~~~~~ .. autosummary:: :toctree: api/ - scan_delta - read_delta - DataFrame.write_delta - -Datasets -~~~~~~~~ -Connect to pyarrow datasets. + scan_iceberg +JSON +~~~~ .. autosummary:: :toctree: api/ - scan_pyarrow_dataset + read_json + read_ndjson + scan_ndjson + DataFrame.write_json + DataFrame.write_ndjson + LazyFrame.sink_ndjson + +Parquet +~~~~~~~ +.. autosummary:: + :toctree: api/ + read_parquet + read_parquet_schema + scan_parquet + DataFrame.write_parquet + LazyFrame.sink_parquet -BatchedCsvReader +PyArrow Datasets ~~~~~~~~~~~~~~~~ -This reader comes available by calling `pl.read_csv_batched`. - -.. currentmodule:: polars.io.csv.batched_reader +Connect to pyarrow datasets. .. autosummary:: :toctree: api/ - BatchedCsvReader.next_batches + scan_pyarrow_dataset diff --git a/py-polars/docs/source/reference/plugins.rst b/py-polars/docs/source/reference/plugins.rst new file mode 100644 index 0000000000000..e49d69f0a1193 --- /dev/null +++ b/py-polars/docs/source/reference/plugins.rst @@ -0,0 +1,15 @@ +======= +Plugins +======= +.. currentmodule:: polars + +Plugins allow for extending Polars' functionality. See the +`user guide `_ for more information +and resources. + +Available plugin utility functions are: + +.. automodule:: polars.plugins + :members: + :autosummary: + :autosummary-no-titles: diff --git a/py-polars/docs/source/reference/sql.rst b/py-polars/docs/source/reference/sql.rst index 2b1c323e7148e..bf28f9cc6e20b 100644 --- a/py-polars/docs/source/reference/sql.rst +++ b/py-polars/docs/source/reference/sql.rst @@ -1,6 +1,6 @@ -=== -SQL -=== +============= +SQL Interface +============= .. currentmodule:: polars .. py:class:: SQLContext diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index d7f093484221f..7a898d0ebc3f0 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -17,7 +17,11 @@ __register_startup_deps() -from polars import api +from polars import api, exceptions, plugins, selectors +from polars._utils.polars_version import get_polars_version as _get_polars_version + +# TODO: remove need for importing wrap utils at top level +from polars._utils.wrap import wrap_df, wrap_s # noqa: F401 from polars.config import Config from polars.convert import ( from_arrow, @@ -214,10 +218,6 @@ using_string_cache, ) from polars.type_aliases import PolarsDataType -from polars.utils._polars_version import get_polars_version as _get_polars_version - -# TODO: remove need for importing wrap utils at top level -from polars.utils._wrap import wrap_df, wrap_s # noqa: F401 __version__: str = _get_polars_version() del _get_polars_version @@ -225,6 +225,7 @@ __all__ = [ "api", "exceptions", + "plugins", # exceptions/errors "ArrowError", "ColumnNotFoundError", diff --git a/py-polars/polars/_cpu_check.py b/py-polars/polars/_cpu_check.py index e6033eac91ef6..77a343befe601 100644 --- a/py-polars/polars/_cpu_check.py +++ b/py-polars/polars/_cpu_check.py @@ -122,10 +122,6 @@ class CPUID_struct(ctypes.Structure): class CPUID: def __init__(self) -> None: - if _POLARS_ARCH != "x86-64": - msg = "CPUID is only available for x86" - raise SystemError(msg) - if _IS_WINDOWS: if _IS_64BIT: # VirtualAlloc seems to fail under some weird @@ -187,11 +183,7 @@ def __del__(self) -> None: self.win.VirtualFree(self.addr, 0, _MEM_RELEASE) -def read_cpu_flags() -> dict[str, bool]: - # Right now we only enable extra feature flags for x86. - if _POLARS_ARCH != "x86-64": - return {} - +def _read_cpu_flags() -> dict[str, bool]: # CPU flags from https://en.wikipedia.org/wiki/CPUID cpuid = CPUID() cpuid1 = cpuid(1, 0) @@ -205,6 +197,7 @@ def read_cpu_flags() -> dict[str, bool]: "sse4.1": bool(cpuid1.ecx & (1 << 19)), "sse4.2": bool(cpuid1.ecx & (1 << 20)), "popcnt": bool(cpuid1.ecx & (1 << 23)), + "pclmulqdq": bool(cpuid1.ecx & (1 << 1)), "avx": bool(cpuid1.ecx & (1 << 28)), "bmi1": bool(cpuid7.ebx & (1 << 3)), "bmi2": bool(cpuid7.ebx & (1 << 8)), @@ -222,12 +215,12 @@ def check_cpu_flags() -> None: return expected_cpu_flags = [f.lstrip("+") for f in _POLARS_FEATURE_FLAGS.split(",")] - supported_cpu_flags = read_cpu_flags() + supported_cpu_flags = _read_cpu_flags() missing_features = [] for f in expected_cpu_flags: if f not in supported_cpu_flags: - msg = f'unknown feature flag "{f}"' + msg = f"unknown feature flag: {f!r}" raise RuntimeError(msg) if not supported_cpu_flags[f]: diff --git a/py-polars/polars/_utils/__init__.py b/py-polars/polars/_utils/__init__.py new file mode 100644 index 0000000000000..266cfa26ff5ae --- /dev/null +++ b/py-polars/polars/_utils/__init__.py @@ -0,0 +1,37 @@ +""" +Utility functions. + +Functions that are part of the public API are re-exported here. +""" + +from polars._utils.convert import ( + date_to_int, + datetime_to_int, + time_to_int, + timedelta_to_int, + to_py_date, + to_py_datetime, + to_py_decimal, + to_py_time, + to_py_timedelta, +) +from polars._utils.scan import _execute_from_rust +from polars._utils.various import NoDefault, _polars_warn, is_column, no_default + +__all__ = [ + "NoDefault", + "is_column", + "no_default", + # Required for Rust bindings + "date_to_int", + "datetime_to_int", + "time_to_int", + "timedelta_to_int", + "_execute_from_rust", + "_polars_warn", + "to_py_date", + "to_py_datetime", + "to_py_decimal", + "to_py_time", + "to_py_timedelta", +] diff --git a/py-polars/polars/utils/_async.py b/py-polars/polars/_utils/async_.py similarity index 98% rename from py-polars/polars/utils/_async.py rename to py-polars/polars/_utils/async_.py index 60104271712e5..e90ae5d614cab 100644 --- a/py-polars/polars/utils/_async.py +++ b/py-polars/polars/_utils/async_.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, TypeVar +from polars._utils.wrap import wrap_df from polars.dependencies import _GEVENT_AVAILABLE -from polars.utils._wrap import wrap_df if TYPE_CHECKING: from asyncio.futures import Future diff --git a/py-polars/polars/_utils/construction/__init__.py b/py-polars/polars/_utils/construction/__init__.py new file mode 100644 index 0000000000000..35f2232b2e0e7 --- /dev/null +++ b/py-polars/polars/_utils/construction/__init__.py @@ -0,0 +1,48 @@ +from polars._utils.construction.dataframe import ( + arrow_to_pydf, + dataframe_to_pydf, + dict_to_pydf, + iterable_to_pydf, + numpy_to_pydf, + pandas_to_pydf, + sequence_to_pydf, + series_to_pydf, +) +from polars._utils.construction.other import ( + coerce_arrow, + numpy_to_idxs, + pandas_series_to_arrow, +) +from polars._utils.construction.series import ( + arrow_to_pyseries, + dataframe_to_pyseries, + iterable_to_pyseries, + numpy_to_pyseries, + pandas_to_pyseries, + sequence_to_pyseries, + series_to_pyseries, +) + +__all__ = [ + # dataframe + "arrow_to_pydf", + "dataframe_to_pydf", + "dict_to_pydf", + "iterable_to_pydf", + "numpy_to_pydf", + "pandas_to_pydf", + "sequence_to_pydf", + "series_to_pydf", + # series + "arrow_to_pyseries", + "dataframe_to_pyseries", + "iterable_to_pyseries", + "numpy_to_pyseries", + "pandas_to_pyseries", + "sequence_to_pyseries", + "series_to_pyseries", + # other + "coerce_arrow", + "numpy_to_idxs", + "pandas_series_to_arrow", +] diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/_utils/construction/dataframe.py similarity index 58% rename from py-polars/polars/utils/_construction.py rename to py-polars/polars/_utils/construction/dataframe.py index 6351d2659888c..58db91397cc31 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/_utils/construction/dataframe.py @@ -1,89 +1,65 @@ from __future__ import annotations import contextlib -import warnings from datetime import date, datetime, time, timedelta -from decimal import Decimal as PyDecimal -from functools import lru_cache, singledispatch +from functools import singledispatch from itertools import islice, zip_longest from operator import itemgetter -from sys import version_info from typing import ( TYPE_CHECKING, Any, Callable, Generator, Iterable, - Iterator, Mapping, MutableMapping, Sequence, - get_type_hints, ) import polars._reexport as pl +import polars._utils.construction as plc from polars import functions as F +from polars._utils.construction.utils import ( + contains_nested, + is_namedtuple, + is_pydantic_model, + nt_unpack, + try_get_type_hints, +) +from polars._utils.various import ( + _is_generator, + arrlen, + parse_version, +) +from polars._utils.wrap import wrap_df, wrap_s from polars.datatypes import ( - INTEGER_DTYPES, N_INFER_DEFAULT, - TEMPORAL_DTYPES, - Boolean, Categorical, - Date, - Datetime, - Duration, Enum, - List, - Null, - Object, String, Struct, - Time, - UInt32, Unknown, - dtype_to_py_type, is_polars_dtype, - numpy_char_code_to_dtype, py_type_to_dtype, ) -from polars.datatypes.constructor import ( - numpy_type_to_constructor, - numpy_values_and_dtype, - polars_type_to_constructor, - py_type_to_constructor, -) from polars.dependencies import ( _NUMPY_AVAILABLE, _check_for_numpy, _check_for_pandas, - _check_for_pydantic, dataclasses, - pydantic, ) from polars.dependencies import numpy as np from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa -from polars.exceptions import ( - ComputeError, - SchemaError, - ShapeError, - TimeZoneAwareConstructorWarning, -) -from polars.meta import get_index_type, thread_pool_size -from polars.utils._wrap import wrap_df, wrap_s -from polars.utils.various import ( - _is_generator, - arrlen, - find_stacklevel, - parse_version, - range_to_series, -) +from polars.exceptions import ShapeError +from polars.meta import thread_pool_size with contextlib.suppress(ImportError): # Module not available when building docs - from polars.polars import PyDataFrame, PySeries + from polars.polars import PyDataFrame if TYPE_CHECKING: from polars import DataFrame, Series + from polars.polars import PySeries from polars.type_aliases import ( Orientation, PolarsDataType, @@ -92,672 +68,83 @@ ) -def _get_annotations(obj: type) -> dict[str, Any]: - return getattr(obj, "__annotations__", {}) - - -if version_info >= (3, 10): - - def type_hints(obj: type) -> dict[str, Any]: - try: - # often the same as obj.__annotations__, but handles forward references - # encoded as string literals, adds Optional[t] if a default value equal - # to None is set and recursively replaces 'Annotated[T, ...]' with 'T'. - return get_type_hints(obj) - except TypeError: - # fallback on edge-cases (eg: InitVar inference on python 3.10). - return _get_annotations(obj) - -else: - type_hints = _get_annotations - - -@lru_cache(64) -def is_namedtuple(cls: Any, *, annotated: bool = False) -> bool: - """Check whether given class derives from NamedTuple.""" - if all(hasattr(cls, attr) for attr in ("_fields", "_field_defaults", "_replace")): - if not isinstance(cls._fields, property): - if not annotated or len(cls.__annotations__) == len(cls._fields): - return all(isinstance(fld, str) for fld in cls._fields) - return False - - -def is_pydantic_model(value: Any) -> bool: - """Check whether value derives from a pydantic.BaseModel.""" - return _check_for_pydantic(value) and isinstance(value, pydantic.BaseModel) - - -def contains_nested(value: Any, is_nested: Callable[[Any], bool]) -> bool: - """Determine if value contains (or is) nested structured data.""" - if is_nested(value): - return True - elif isinstance(value, dict): - return any(contains_nested(v, is_nested) for v in value.values()) - elif isinstance(value, (list, tuple)): - return any(contains_nested(v, is_nested) for v in value) - return False - - -def include_unknowns( - schema: SchemaDict, cols: Sequence[str] -) -> MutableMapping[str, PolarsDataType]: - """Complete partial schema dict by including Unknown type.""" - return { - col: ( - schema.get(col, Unknown) or Unknown # type: ignore[truthy-bool] - ) - for col in cols - } - - -def nt_unpack(obj: Any) -> Any: - """Recursively unpack a nested NamedTuple.""" - if isinstance(obj, dict): - return {key: nt_unpack(value) for key, value in obj.items()} - elif isinstance(obj, list): - return [nt_unpack(value) for value in obj] - elif is_namedtuple(obj.__class__): - return {key: nt_unpack(value) for key, value in obj._asdict().items()} - elif isinstance(obj, tuple): - return tuple(nt_unpack(value) for value in obj) - else: - return obj - - -################################ -# Series constructor interface # -################################ - - -def series_to_pyseries( - name: str | None, - values: Series, - *, - dtype: PolarsDataType | None = None, - strict: bool = True, -) -> PySeries: - """Construct a new PySeries from a Polars Series.""" - s = values.clone() - if dtype is not None and dtype != s.dtype: - s = s.cast(dtype, strict=strict) - if name is not None: - s = s.alias(name) - return s._s - - -def dataframe_to_pyseries( - name: str | None, - values: DataFrame, - *, - dtype: PolarsDataType | None = None, - strict: bool = True, -) -> PySeries: - """Construct a new PySeries from a Polars DataFrame.""" - if values.width > 1: - name = name or "" - s = values.to_struct(name) - elif values.width == 1: - s = values.to_series() - if name is not None: - s = s.alias(name) - else: - msg = "cannot initialize Series from DataFrame without any columns" - raise TypeError(msg) - - if dtype is not None and dtype != s.dtype: - s = s.cast(dtype, strict=strict) - - return s._s - - -def arrow_to_pyseries(name: str, values: pa.Array, *, rechunk: bool = True) -> PySeries: - """Construct a PySeries from an Arrow array.""" - array = coerce_arrow(values) - - # special handling of empty categorical arrays - if ( - len(array) == 0 - and isinstance(array.type, pa.DictionaryType) - and array.type.value_type - in ( - pa.utf8(), - pa.large_utf8(), - ) - ): - pys = pl.Series(name, [], dtype=Categorical)._s - - elif not hasattr(array, "num_chunks"): - pys = PySeries.from_arrow(name, array) - else: - if array.num_chunks > 1: - # somehow going through ffi with a structarray - # returns the first chunk every time - if isinstance(array.type, pa.StructType): - pys = PySeries.from_arrow(name, array.combine_chunks()) - else: - it = array.iterchunks() - pys = PySeries.from_arrow(name, next(it)) - for a in it: - pys.append(PySeries.from_arrow(name, a)) - elif array.num_chunks == 0: - pys = PySeries.from_arrow(name, pa.array([], array.type)) - else: - pys = PySeries.from_arrow(name, array.chunks[0]) - - if rechunk: - pys.rechunk(in_place=True) - - return pys - - -def numpy_to_pyseries( - name: str, - values: np.ndarray[Any, Any], +def dict_to_pydf( + data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series], + schema: SchemaDefinition | None = None, *, + schema_overrides: SchemaDict | None = None, strict: bool = True, nan_to_null: bool = False, -) -> PySeries: - """Construct a PySeries from a numpy array.""" - values = np.ascontiguousarray(values) - - if values.ndim == 1: - values, dtype = numpy_values_and_dtype(values) - constructor = numpy_type_to_constructor(values, dtype) - return constructor( - name, values, nan_to_null if dtype in (np.float32, np.float64) else strict - ) - elif values.ndim == 2: - # Optimize by ingesting 1D and reshaping in Rust - original_shape = values.shape - values = values.reshape(-1) - py_s = numpy_to_pyseries( - name, - values, - strict=strict, - nan_to_null=nan_to_null, - ) - return ( - PyDataFrame([py_s]) - .lazy() - .select([F.col(name).reshape(original_shape)._pyexpr]) - .collect() - .select_at_idx(0) - ) - else: - return PySeries.new_object(name, values, strict) - - -def _get_first_non_none(values: Sequence[Any | None]) -> Any: - """ - Return the first value from a sequence that isn't None. - - If sequence doesn't contain non-None values, return None. - """ - if values is not None: - return next((v for v in values if v is not None), None) - - -def sequence_from_any_value_or_object(name: str, values: Sequence[Any]) -> PySeries: - """ - Last resort conversion. +) -> PyDataFrame: + """Construct a PyDataFrame from a dictionary of sequences.""" + if isinstance(schema, Mapping) and data: + if not all((col in schema) for col in data): + msg = "the given column-schema names do not match the data dictionary" + raise ValueError(msg) + data = {col: data[col] for col in schema} - AnyValues are most flexible and if they fail we go for object types - """ - try: - return PySeries.new_from_any_values(name, values, strict=True) - # raised if we cannot convert to Wrap - except RuntimeError: - return PySeries.new_object(name, values, _strict=False) - # raised if AnyValue fallbacks fail - except SchemaError: - return PySeries.new_object(name, values, _strict=False) - except ComputeError as exc: - if "mixed dtypes" in str(exc): - return PySeries.new_object(name, values, _strict=False) - raise - - -def sequence_from_any_value_and_dtype_or_object( - name: str, values: Sequence[Any], dtype: PolarsDataType -) -> PySeries: - """ - Last resort conversion. + column_names, schema_overrides = _unpack_schema( + schema, lookup_names=data.keys(), schema_overrides=schema_overrides + ) + if not column_names: + column_names = list(data) - AnyValues are most flexible and if they fail we go for object types - """ - try: - return PySeries.new_from_any_values_and_dtype(name, values, dtype, strict=True) - # raised if we cannot convert to Wrap - except RuntimeError: - return PySeries.new_object(name, values, _strict=False) - except ComputeError as exc: - if "mixed dtypes" in str(exc): - return PySeries.new_object(name, values, _strict=False) - raise - - -def iterable_to_pyseries( - name: str, - values: Iterable[Any], - dtype: PolarsDataType | None = None, - *, - chunk_size: int = 1_000_000, - strict: bool = True, -) -> PySeries: - """Construct a PySeries from an iterable/generator.""" - if not isinstance(values, (Generator, Iterator)): - values = iter(values) - - def to_series_chunk(values: list[Any], dtype: PolarsDataType | None) -> Series: - return pl.Series( - name=name, - values=values, - dtype=dtype, - strict=strict, + if data and _NUMPY_AVAILABLE: + # if there are 3 or more numpy arrays of sufficient size, we multi-thread: + count_numpy = sum( + int( + _check_for_numpy(val) + and isinstance(val, np.ndarray) + and len(val) > 1000 + ) + for val in data.values() ) + if count_numpy >= 3: + # yes, multi-threading was easier in python here; we cannot have multiple + # threads running python and release the gil in pyo3 (it will deadlock). - n_chunks = 0 - series: Series = None # type: ignore[assignment] - while True: - slice_values = list(islice(values, chunk_size)) - if not slice_values: - break - schunk = to_series_chunk(slice_values, dtype) - if series is None: - series = schunk - dtype = series.dtype - else: - series.append(schunk) - n_chunks += 1 - - if series is None: - series = to_series_chunk([], dtype) - if n_chunks > 0: - series.rechunk(in_place=True) - - return series._s - - -def _construct_series_with_fallbacks( - constructor: Callable[[str, Sequence[Any], bool], PySeries], - name: str, - values: Sequence[Any], - target_dtype: PolarsDataType | None, - *, - strict: bool, -) -> PySeries: - """Construct Series, with fallbacks for basic type mismatch (eg: bool/int).""" - while True: - try: - return constructor(name, values, strict) - except TypeError as exc: - str_exc = str(exc) - - # from x to float - # error message can be: - # - integers: "'float' object cannot be interpreted as an integer" - if "'float'" in str_exc and ( - # we do not accept float values as int/temporal, as it causes silent - # information loss; the caller should explicitly cast in this case. - target_dtype not in (INTEGER_DTYPES | TEMPORAL_DTYPES) - ): - constructor = py_type_to_constructor(float) - - # from x to string - # error message can be: - # - integers: "'str' object cannot be interpreted as an integer" - # - floats: "must be real number, not str" - elif "'str'" in str_exc or str_exc == "must be real number, not str": - constructor = py_type_to_constructor(str) - - # from x to int - # error message can be: - # - bools: "'int' object cannot be converted to 'PyBool'" - elif str_exc == "'int' object cannot be converted to 'PyBool'": - constructor = py_type_to_constructor(int) - - elif "decimal.Decimal" in str_exc: - constructor = py_type_to_constructor(PyDecimal) - else: - raise - - -def sequence_to_pyseries( - name: str, - values: Sequence[Any], - dtype: PolarsDataType | None = None, - *, - strict: bool = True, - nan_to_null: bool = False, -) -> PySeries: - """Construct a PySeries from a sequence.""" - python_dtype: type | None = None - - if isinstance(values, range): - return range_to_series(name, values, dtype=dtype)._s - - # empty sequence - if not values and dtype is None: - # if dtype for empty sequence could be guessed - # (e.g comparisons between self and other), default to Null - dtype = Null - - # lists defer to subsequent handling; identify nested type - elif dtype == List: - getattr(dtype, "inner", None) - python_dtype = list - - # infer temporal type handling - py_temporal_types = {date, datetime, timedelta, time} - pl_temporal_types = {Date, Datetime, Duration, Time} - - value = _get_first_non_none(values) - if value is not None: - if ( - dataclasses.is_dataclass(value) - or is_pydantic_model(value) - or is_namedtuple(value.__class__) - ) and dtype != Object: - return pl.DataFrame(values).to_struct(name)._s - elif isinstance(value, range): - values = [range_to_series("", v) for v in values] - else: - # for temporal dtypes: - # * if the values are integer, we take the physical branch. - # * if the values are python types, take the temporal branch. - # * if the values are ISO-8601 strings, init then convert via strptime. - # * if the values are floats/other dtypes, this is an error. - if dtype in py_temporal_types and isinstance(value, int): - dtype = py_type_to_dtype(dtype) # construct from integer - elif ( - dtype in pl_temporal_types or type(dtype) in pl_temporal_types - ) and not isinstance(value, int): - python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type] - - # physical branch - # flat data - if ( - dtype is not None - and dtype not in (List, Struct, Unknown) - and is_polars_dtype(dtype) - and (python_dtype is None) - ): - constructor = polars_type_to_constructor(dtype) - pyseries = _construct_series_with_fallbacks( - constructor, name, values, dtype, strict=strict - ) - if dtype in (Date, Datetime, Duration, Time, Categorical, Boolean, Enum): - if pyseries.dtype() != dtype: - pyseries = pyseries.cast(dtype, strict=True) - return pyseries - - elif dtype == Struct: - struct_schema = dtype.to_schema() if isinstance(dtype, Struct) else None - empty = {} # type: ignore[var-annotated] - return sequence_to_pydf( - data=[(empty if v is None else v) for v in values], - schema=struct_schema, - orient="row", - ).to_struct(name) - else: - if python_dtype is None: - if value is None: - constructor = polars_type_to_constructor(Null) - return constructor(name, values, strict) - - # generic default dtype - python_dtype = type(value) - - # temporal branch - if python_dtype in py_temporal_types: - if dtype is None: - dtype = py_type_to_dtype(python_dtype) # construct from integer - elif dtype in py_temporal_types: - dtype = py_type_to_dtype(dtype) - - values_dtype = ( - None - if value is None - else py_type_to_dtype(type(value), raise_unmatched=False) - ) - if values_dtype is not None and values_dtype.is_float(): - msg = f"'float' object cannot be interpreted as a {python_dtype.__name__!r}" - raise TypeError( - # we do not accept float values as temporal; if this is - # required, the caller should explicitly cast to int first. - msg - ) + # (note: 'dummy' is threaded) + import multiprocessing.dummy - # We use the AnyValue builder to create the datetime array - # We store the values internally as UTC and set the timezone - py_series = PySeries.new_from_any_values(name, values, strict) - time_unit = getattr(dtype, "time_unit", None) - if time_unit is None or values_dtype == Date: - s = wrap_s(py_series) - else: - s = wrap_s(py_series).dt.cast_time_unit(time_unit) - time_zone = getattr(dtype, "time_zone", None) - - if (values_dtype == Date) & (dtype == Datetime): - return s.cast(Datetime(time_unit)).dt.replace_time_zone(time_zone)._s - - if (dtype == Datetime) and ( - value.tzinfo is not None or time_zone is not None - ): - values_tz = str(value.tzinfo) if value.tzinfo is not None else None - dtype_tz = dtype.time_zone # type: ignore[union-attr] - if values_tz is not None and ( - dtype_tz is not None and dtype_tz != "UTC" - ): - msg = ( - "time-zone-aware datetimes are converted to UTC" - "\n\nPlease either drop the time zone from the dtype, or set it to 'UTC'." - " To convert to a different time zone, please use `.dt.convert_time_zone`." - ) - raise ValueError(msg) - if values_tz != "UTC" and dtype_tz is None: - warnings.warn( - "Constructing a Series with time-zone-aware " - "datetimes results in a Series with UTC time zone. " - "To silence this warning, you can filter " - "warnings of class TimeZoneAwareConstructorWarning, or " - "set 'UTC' as the time zone of your datatype.", - TimeZoneAwareConstructorWarning, - stacklevel=find_stacklevel(), + pool_size = thread_pool_size() + with multiprocessing.dummy.Pool(pool_size) as pool: + data = dict( + zip( + column_names, + pool.map( + lambda t: pl.Series(t[0], t[1]) + if isinstance(t[1], np.ndarray) + else t[1], + list(data.items()), + ), ) - return s.dt.replace_time_zone(dtype_tz or "UTC")._s - return s._s - - elif ( - _check_for_numpy(value) - and isinstance(value, np.ndarray) - and len(value.shape) == 1 - ): - n_elems = len(value) - if all(len(v) == n_elems for v in values): - # can take (much) faster path if all lists are the same length - return numpy_to_pyseries( - name, - np.vstack(values), - strict=strict, - nan_to_null=nan_to_null, - ) - else: - return PySeries.new_series_list( - name, - [ - numpy_to_pyseries("", v, strict=strict, nan_to_null=nan_to_null) - for v in values - ], - strict, ) - elif python_dtype in (list, tuple): - if isinstance(dtype, Object): - return PySeries.new_object(name, values, strict) - if dtype: - srs = sequence_from_any_value_and_dtype_or_object(name, values, dtype) - if not dtype.is_(srs.dtype()): - srs = srs.cast(dtype, strict=False) - return srs - return sequence_from_any_value_or_object(name, values) - - elif python_dtype == pl.Series: - return PySeries.new_series_list(name, [v._s for v in values], strict) - - elif python_dtype == PySeries: - return PySeries.new_series_list(name, values, strict) - else: - constructor = py_type_to_constructor(python_dtype) - if constructor == PySeries.new_object: - try: - srs = PySeries.new_from_any_values(name, values, strict) - if _check_for_numpy(python_dtype, check_type=False) and isinstance( - np.bool_(True), np.generic - ): - dtype = numpy_char_code_to_dtype(np.dtype(python_dtype).char) - return srs.cast(dtype, strict=strict) - else: - return srs - - except RuntimeError: - # raised if we cannot convert to Wrap - return sequence_from_any_value_or_object(name, values) - - return _construct_series_with_fallbacks( - constructor, name, values, dtype, strict=strict - ) - - -def _pandas_series_to_arrow( - values: pd.Series[Any] | pd.Index[Any], - *, - length: int | None = None, - nan_to_null: bool = True, -) -> pa.Array: - """ - Convert a pandas Series to an Arrow Array. - - Parameters - ---------- - values : :class:`pandas.Series` or :class:`pandas.Index`. - Series to convert to arrow - nan_to_null : bool, default = True - Interpret `NaN` as missing values. - length : int, optional - in case all values are null, create a null array of this length. - if unset, length is inferred from values. - - Returns - ------- - :class:`pyarrow.Array` - """ - dtype = getattr(values, "dtype", None) - if dtype == "object": - first_non_none = _get_first_non_none(values.values) # type: ignore[arg-type] - if isinstance(first_non_none, str): - return pa.array(values, pa.large_utf8(), from_pandas=nan_to_null) - elif first_non_none is None: - return pa.nulls(length or len(values), pa.large_utf8()) - return pa.array(values, from_pandas=nan_to_null) - elif dtype: - return pa.array(values, from_pandas=nan_to_null) - else: - # Pandas Series is actually a Pandas DataFrame when the original DataFrame - # contains duplicated columns and a duplicated column is requested with df["a"]. - msg = "duplicate column names found: " - raise ValueError( - msg, - f"{values.columns.tolist()!s}", # type: ignore[union-attr] - ) - - -def pandas_to_pyseries( - name: str, - values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, - *, - nan_to_null: bool = True, -) -> PySeries: - """Construct a PySeries from a pandas Series or DatetimeIndex.""" - # TODO: Change `if not name` to `if name is not None` once name is Optional[str] - if not name and values.name is not None: - name = str(values.name) - return arrow_to_pyseries( - name, _pandas_series_to_arrow(values, nan_to_null=nan_to_null) - ) - - -################################### -# DataFrame constructor interface # -################################### - - -def _handle_columns_arg( - data: list[PySeries], - columns: Sequence[str] | None = None, - *, - from_dict: bool = False, -) -> list[PySeries]: - """Rename data according to columns argument.""" - if not columns: - return data + if not data and schema_overrides: + data_series = [ + pl.Series( + name, [], dtype=schema_overrides.get(name), nan_to_null=nan_to_null + )._s + for name in column_names + ] else: - if not data: - return [pl.Series(c, None)._s for c in columns] - elif len(data) == len(columns): - if from_dict: - series_map = {s.name(): s for s in data} - if all((col in series_map) for col in columns): - return [series_map[col] for col in columns] - for i, c in enumerate(columns): - if c != data[i].name(): - data[i] = data[i].clone() - data[i].rename(c) - return data - else: - msg = f"dimensions of columns arg ({len(columns)}) must match data dimensions ({len(data)})" - raise ValueError(msg) - - -def _post_apply_columns( - pydf: PyDataFrame, - columns: SchemaDefinition | None, - structs: dict[str, Struct] | None = None, - schema_overrides: SchemaDict | None = None, -) -> PyDataFrame: - """Apply 'columns' param *after* PyDataFrame creation (if no alternative).""" - pydf_columns, pydf_dtypes = pydf.columns(), pydf.dtypes() - columns, dtypes = _unpack_schema( - (columns or pydf_columns), schema_overrides=schema_overrides - ) - column_subset: list[str] = [] - if columns != pydf_columns: - if len(columns) < len(pydf_columns) and columns == pydf_columns[: len(columns)]: - column_subset = columns - else: - pydf.set_column_names(columns) - - column_casts = [] - for i, col in enumerate(columns): - dtype = dtypes.get(col) - pydf_dtype = pydf_dtypes[i] - if dtype == Categorical != pydf_dtype: - column_casts.append(F.col(col).cast(Categorical)._pyexpr) - elif dtype == Enum != pydf_dtype: - column_casts.append(F.col(col).cast(dtype)._pyexpr) - elif structs and (struct := structs.get(col)) and struct != pydf_dtype: - column_casts.append(F.col(col).cast(struct)._pyexpr) - elif dtype is not None and dtype != Unknown and dtype != pydf_dtype: - column_casts.append(F.col(col).cast(dtype)._pyexpr) + data_series = [ + s._s + for s in _expand_dict_scalars( + data, + schema_overrides=schema_overrides, + strict=strict, + nan_to_null=nan_to_null, + ).values() + ] - if column_casts or column_subset: - pydf = pydf.lazy() - if column_casts: - pydf = pydf.with_columns(column_casts) - if column_subset: - pydf = pydf.select([F.col(col)._pyexpr for col in column_subset]) - pydf = pydf.collect() + data_series = _handle_columns_arg(data_series, columns=column_names, from_dict=True) + pydf = PyDataFrame(data_series) + if schema_overrides and pydf.dtypes() != list(schema_overrides.values()): + pydf = _post_apply_columns( + pydf, column_names, schema_overrides=schema_overrides, strict=strict + ) return pydf @@ -844,27 +231,83 @@ def _parse_schema_overrides( return column_names, column_dtypes -def _expand_dict_data( - data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series], - dtypes: SchemaDict, -) -> Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series]: - """ - Expand any unsized generators/iterators. - - (Note that `range` is sized, and will take a fast-path on Series init). - """ - expanded_data = {} - for name, val in data.items(): - expanded_data[name] = ( - pl.Series(name, val, dtypes.get(name)) if _is_generator(val) else val - ) - return expanded_data +def _handle_columns_arg( + data: list[PySeries], + columns: Sequence[str] | None = None, + *, + from_dict: bool = False, +) -> list[PySeries]: + """Rename data according to columns argument.""" + if columns is None: + return data + elif not data: + return [pl.Series(name=c)._s for c in columns] + elif len(data) != len(columns): + msg = f"dimensions of columns arg ({len(columns)}) must match data dimensions ({len(data)})" + raise ValueError(msg) + + if from_dict: + series_map = {s.name(): s for s in data} + if all((col in series_map) for col in columns): + return [series_map[col] for col in columns] + + for i, c in enumerate(columns): + if c != data[i].name(): + data[i] = data[i].clone() + data[i].rename(c) + + return data + + +def _post_apply_columns( + pydf: PyDataFrame, + columns: SchemaDefinition | None, + structs: dict[str, Struct] | None = None, + schema_overrides: SchemaDict | None = None, + *, + strict: bool = True, +) -> PyDataFrame: + """Apply 'columns' param *after* PyDataFrame creation (if no alternative).""" + pydf_columns, pydf_dtypes = pydf.columns(), pydf.dtypes() + columns, dtypes = _unpack_schema( + (columns or pydf_columns), schema_overrides=schema_overrides + ) + column_subset: list[str] = [] + if columns != pydf_columns: + if len(columns) < len(pydf_columns) and columns == pydf_columns[: len(columns)]: + column_subset = columns + else: + pydf.set_column_names(columns) + + column_casts = [] + for i, col in enumerate(columns): + dtype = dtypes.get(col) + pydf_dtype = pydf_dtypes[i] + if dtype == Categorical != pydf_dtype: + column_casts.append(F.col(col).cast(Categorical, strict=strict)._pyexpr) + elif dtype == Enum != pydf_dtype: + column_casts.append(F.col(col).cast(dtype, strict=strict)._pyexpr) + elif structs and (struct := structs.get(col)) and struct != pydf_dtype: + column_casts.append(F.col(col).cast(struct, strict=strict)._pyexpr) + elif dtype is not None and dtype != Unknown and dtype != pydf_dtype: + column_casts.append(F.col(col).cast(dtype, strict=strict)._pyexpr) + + if column_casts or column_subset: + pydf = pydf.lazy() + if column_casts: + pydf = pydf.with_columns(column_casts) + if column_subset: + pydf = pydf.select([F.col(col)._pyexpr for col in column_subset]) + pydf = pydf.collect() + + return pydf def _expand_dict_scalars( data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series], *, schema_overrides: SchemaDict | None = None, + strict: bool = True, order: Sequence[str] | None = None, nan_to_null: bool = False, ) -> dict[str, Series]: @@ -881,23 +324,29 @@ def _expand_dict_scalars( raise TypeError(msg) dtypes = schema_overrides or {} - data = _expand_dict_data(data, dtypes) + data = _expand_dict_data(data, dtypes, strict=strict) array_len = max((arrlen(val) or 0) for val in data.values()) if array_len > 0: for name, val in data.items(): dtype = dtypes.get(name) if isinstance(val, dict) and dtype != Struct: - updated_data[name] = pl.DataFrame(val).to_struct(name) + updated_data[name] = pl.DataFrame(val, strict=strict).to_struct( + name + ) elif isinstance(val, pl.Series): s = val.rename(name) if name != val.name else val if dtype and dtype != s.dtype: - s = s.cast(dtype) + s = s.cast(dtype, strict=strict) updated_data[name] = s elif arrlen(val) is not None or _is_generator(val): updated_data[name] = pl.Series( - name=name, values=val, dtype=dtype, nan_to_null=nan_to_null + name=name, + values=val, + dtype=dtype, + strict=strict, + nan_to_null=nan_to_null, ) elif val is None or isinstance( # type: ignore[redundant-expr] val, (int, float, str, bool, date, datetime, time, timedelta) @@ -907,12 +356,14 @@ def _expand_dict_scalars( ).alias(name) else: updated_data[name] = pl.Series( - name=name, values=[val] * array_len, dtype=dtype + name=name, values=[val] * array_len, dtype=dtype, strict=strict ) elif all((arrlen(val) == 0) for val in data.values()): for name, val in data.items(): - updated_data[name] = pl.Series(name, values=val, dtype=dtypes.get(name)) + updated_data[name] = pl.Series( + name, values=val, dtype=dtypes.get(name), strict=strict + ) elif all((arrlen(val) is None) for val in data.values()): for name, val in data.items(): @@ -920,97 +371,45 @@ def _expand_dict_scalars( name, values=(val if _is_generator(val) else [val]), dtype=dtypes.get(name), + strict=strict, ) if order and list(updated_data) != order: return {col: updated_data.pop(col) for col in order} return updated_data -def dict_to_pydf( +def _expand_dict_data( data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series], - schema: SchemaDefinition | None = None, + dtypes: SchemaDict, *, - schema_overrides: SchemaDict | None = None, - nan_to_null: bool = False, -) -> PyDataFrame: - """Construct a PyDataFrame from a dictionary of sequences.""" - if isinstance(schema, Mapping) and data: - if not all((col in schema) for col in data): - msg = "the given column-schema names do not match the data dictionary" - raise ValueError(msg) - data = {col: data[col] for col in schema} - - column_names, schema_overrides = _unpack_schema( - schema, lookup_names=data.keys(), schema_overrides=schema_overrides - ) - if not column_names: - column_names = list(data) - - if data and _NUMPY_AVAILABLE: - # if there are 3 or more numpy arrays of sufficient size, we multi-thread: - count_numpy = sum( - int( - _check_for_numpy(val) - and isinstance(val, np.ndarray) - and len(val) > 1000 - ) - for val in data.values() - ) - if count_numpy >= 3: - # yes, multi-threading was easier in python here; we cannot have multiple - # threads running python and release the gil in pyo3 (it will deadlock). - - # (note: 'dummy' is threaded) - import multiprocessing.dummy - - pool_size = thread_pool_size() - with multiprocessing.dummy.Pool(pool_size) as pool: - data = dict( - zip( - column_names, - pool.map( - lambda t: pl.Series(t[0], t[1]) - if isinstance(t[1], np.ndarray) - else t[1], - list(data.items()), - ), - ) - ) - - if not data and schema_overrides: - data_series = [ - pl.Series( - name, [], dtype=schema_overrides.get(name), nan_to_null=nan_to_null - )._s - for name in column_names - ] - else: - data_series = [ - s._s - for s in _expand_dict_scalars( - data, schema_overrides=schema_overrides, nan_to_null=nan_to_null - ).values() - ] - - data_series = _handle_columns_arg(data_series, columns=column_names, from_dict=True) - pydf = PyDataFrame(data_series) + strict: bool = True, +) -> Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series]: + """ + Expand any unsized generators/iterators. - if schema_overrides and pydf.dtypes() != list(schema_overrides.values()): - pydf = _post_apply_columns( - pydf, column_names, schema_overrides=schema_overrides + (Note that `range` is sized, and will take a fast-path on Series init). + """ + expanded_data = {} + for name, val in data.items(): + expanded_data[name] = ( + pl.Series(name, val, dtypes.get(name), strict=strict) + if _is_generator(val) + else val ) - return pydf + return expanded_data def sequence_to_pydf( data: Sequence[Any], schema: SchemaDefinition | None = None, + *, schema_overrides: SchemaDict | None = None, + strict: bool = True, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, ) -> PyDataFrame: """Construct a PyDataFrame from a sequence.""" - if len(data) == 0: + if not data: return dict_to_pydf({}, schema=schema, schema_overrides=schema_overrides) return _sequence_to_pydf_dispatcher( @@ -1018,43 +417,20 @@ def sequence_to_pydf( data=data, schema=schema, schema_overrides=schema_overrides, + strict=strict, orient=orient, infer_schema_length=infer_schema_length, ) -def _sequence_of_series_to_pydf( - first_element: Series, - data: Sequence[Any], - schema: SchemaDefinition | None, - schema_overrides: SchemaDict | None, - **kwargs: Any, -) -> PyDataFrame: - series_names = [s.name for s in data] - column_names, schema_overrides = _unpack_schema( - schema or series_names, - schema_overrides=schema_overrides, - n_expected=len(data), - ) - data_series: list[PySeries] = [] - for i, s in enumerate(data): - if not s.name: - s = s.alias(column_names[i]) - new_dtype = schema_overrides.get(column_names[i]) - if new_dtype and new_dtype != s.dtype: - s = s.cast(new_dtype) - data_series.append(s._s) - - data_series = _handle_columns_arg(data_series, columns=column_names) - return PyDataFrame(data_series) - - @singledispatch def _sequence_to_pydf_dispatcher( first_element: Any, data: Sequence[Any], schema: SchemaDefinition | None, + *, schema_overrides: SchemaDict | None, + strict: bool = True, orient: Orientation | None, infer_schema_length: int | None, ) -> PyDataFrame: @@ -1067,6 +443,7 @@ def _sequence_to_pydf_dispatcher( "data": data, "schema": schema, "schema_overrides": schema_overrides, + "strict": strict, "orient": orient, "infer_schema_length": infer_schema_length, } @@ -1092,10 +469,10 @@ def _sequence_to_pydf_dispatcher( to_pydf = _sequence_of_pandas_to_pydf elif dataclasses.is_dataclass(first_element): - to_pydf = _dataclasses_to_pydf + to_pydf = _sequence_of_dataclasses_to_pydf elif is_pydantic_model(first_element): - to_pydf = _pydantic_models_to_pydf + to_pydf = _sequence_of_pydantic_models_to_pydf else: to_pydf = _sequence_of_elements_to_pydf @@ -1111,7 +488,9 @@ def _sequence_of_sequence_to_pydf( first_element: Sequence[Any] | np.ndarray[Any, Any], data: Sequence[Any], schema: SchemaDefinition | None, + *, schema_overrides: SchemaDict | None, + strict: bool, orient: Orientation | None, infer_schema_length: int | None, ) -> PyDataFrame: @@ -1134,7 +513,9 @@ def _sequence_of_sequence_to_pydf( schema, schema_overrides=schema_overrides, n_expected=len(first_element) ) local_schema_override = ( - include_unknowns(schema_overrides, column_names) if schema_overrides else {} + _include_unknowns(schema_overrides, column_names) + if schema_overrides + else {} ) if ( column_names @@ -1164,7 +545,7 @@ def _sequence_of_sequence_to_pydf( ) if column_names or schema_overrides: pydf = _post_apply_columns( - pydf, column_names, schema_overrides=schema_overrides + pydf, column_names, schema_overrides=schema_overrides, strict=strict ) return pydf @@ -1184,12 +565,42 @@ def _sequence_of_sequence_to_pydf( raise ValueError(msg) +def _sequence_of_series_to_pydf( + first_element: Series, + data: Sequence[Any], + schema: SchemaDefinition | None, + *, + schema_overrides: SchemaDict | None, + strict: bool, + **kwargs: Any, +) -> PyDataFrame: + series_names = [s.name for s in data] + column_names, schema_overrides = _unpack_schema( + schema or series_names, + schema_overrides=schema_overrides, + n_expected=len(data), + ) + data_series: list[PySeries] = [] + for i, s in enumerate(data): + if not s.name: + s = s.alias(column_names[i]) + new_dtype = schema_overrides.get(column_names[i]) + if new_dtype and new_dtype != s.dtype: + s = s.cast(new_dtype, strict=strict) + data_series.append(s._s) + + data_series = _handle_columns_arg(data_series, columns=column_names) + return PyDataFrame(data_series) + + @_sequence_to_pydf_dispatcher.register(tuple) def _sequence_of_tuple_to_pydf( first_element: tuple[Any, ...], data: Sequence[Any], schema: SchemaDefinition | None, + *, schema_overrides: SchemaDict | None, + strict: bool, orient: Orientation | None, infer_schema_length: int | None, ) -> PyDataFrame: @@ -1212,6 +623,7 @@ def _sequence_of_tuple_to_pydf( data=data, schema=schema, schema_overrides=schema_overrides, + strict=strict, orient=orient, infer_schema_length=infer_schema_length, ) @@ -1222,7 +634,9 @@ def _sequence_of_dict_to_pydf( first_element: Any, data: Sequence[Any], schema: SchemaDefinition | None, + *, schema_overrides: SchemaDict | None, + strict: bool, infer_schema_length: int | None, **kwargs: Any, ) -> PyDataFrame: @@ -1230,7 +644,7 @@ def _sequence_of_dict_to_pydf( schema, schema_overrides=schema_overrides ) dicts_schema = ( - include_unknowns(schema_overrides, column_names or list(schema_overrides)) + _include_unknowns(schema_overrides, column_names or list(schema_overrides)) if column_names else None ) @@ -1242,9 +656,7 @@ def _sequence_of_dict_to_pydf( # once https://github.com/pola-rs/polars/issues/11044 is fixed if schema_overrides: pydf = _post_apply_columns( - pydf, - columns=column_names, - schema_overrides=schema_overrides, + pydf, columns=column_names, schema_overrides=schema_overrides, strict=strict ) return pydf @@ -1284,6 +696,8 @@ def _sequence_of_pandas_to_pydf( data: Sequence[Any], schema: SchemaDefinition | None, schema_overrides: SchemaDict | None, + *, + strict: bool, **kwargs: Any, ) -> PyDataFrame: if schema is None: @@ -1297,71 +711,26 @@ def _sequence_of_pandas_to_pydf( data_series: list[PySeries] = [] for i, s in enumerate(data): name = column_names[i] if column_names else s.name - dtype = schema_overrides.get(name, None) - pyseries = pandas_to_pyseries(name=name, values=s) + pyseries = plc.pandas_to_pyseries(name=name, values=s) + dtype = schema_overrides.get(name) if dtype is not None and dtype != pyseries.dtype(): - pyseries = pyseries.cast(dtype, strict=True) + pyseries = pyseries.cast(dtype, strict=strict) data_series.append(pyseries) return PyDataFrame(data_series) -def _establish_dataclass_or_model_schema( - first_element: Any, - schema: SchemaDefinition | None, - schema_overrides: SchemaDict | None, - model_fields: list[str] | None, -) -> tuple[bool, list[str], SchemaDict, SchemaDict]: - """Shared utility code for establishing dataclasses/pydantic model cols/schema.""" - from dataclasses import asdict - - unpack_nested = False - if schema: - column_names, schema_overrides = _unpack_schema( - schema, schema_overrides=schema_overrides - ) - overrides = {col: schema_overrides.get(col, Unknown) for col in column_names} - else: - column_names = [] - overrides = { - col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown) - for col, tp in type_hints(first_element.__class__).items() - if ((col in model_fields) if model_fields else (col != "__slots__")) - } - if schema_overrides: - overrides.update(schema_overrides) - elif not model_fields: - dc_fields = set(asdict(first_element)) - schema_overrides = overrides = { - nm: tp for nm, tp in overrides.items() if nm in dc_fields - } - else: - schema_overrides = overrides - - for col, tp in overrides.items(): - if tp in (Categorical, Enum): - overrides[col] = String - elif not unpack_nested and (tp.base_type() in (Unknown, Struct)): - unpack_nested = contains_nested( - getattr(first_element, col, None), - is_pydantic_model if model_fields else dataclasses.is_dataclass, # type: ignore[arg-type] - ) - - if model_fields and len(model_fields) == len(overrides): - overrides = dict(zip(model_fields, overrides.values())) - - return unpack_nested, column_names, schema_overrides, overrides - - -def _dataclasses_to_pydf( +def _sequence_of_dataclasses_to_pydf( first_element: Any, data: Sequence[Any], schema: SchemaDefinition | None, schema_overrides: SchemaDict | None, infer_schema_length: int | None, + *, + strict: bool = True, **kwargs: Any, ) -> PyDataFrame: - """Initialise DataFrame from python dataclasses.""" + """Initialize DataFrame from Python dataclasses.""" from dataclasses import asdict, astuple ( @@ -1381,18 +750,22 @@ def _dataclasses_to_pydf( if overrides: structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)} - pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides) + pydf = _post_apply_columns( + pydf, column_names, structs, schema_overrides, strict=strict + ) return pydf -def _pydantic_models_to_pydf( +def _sequence_of_pydantic_models_to_pydf( first_element: Any, data: Sequence[Any], schema: SchemaDefinition | None, schema_overrides: SchemaDict | None, infer_schema_length: int | None, - **kwargs: Any, + *, + strict: bool, + **kwargs: Any, ) -> PyDataFrame: """Initialise DataFrame from pydantic model objects.""" import pydantic # note: must already be available in the env here @@ -1429,9 +802,313 @@ def _pydantic_models_to_pydf( dicts = [md.__dict__ for md in data] pydf = PyDataFrame.read_dicts(dicts, infer_schema_length, overrides) - if overrides: - structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)} - pydf = _post_apply_columns(pydf, column_names, structs, schema_overrides) + if overrides: + structs = {c: tp for c, tp in overrides.items() if isinstance(tp, Struct)} + pydf = _post_apply_columns( + pydf, column_names, structs, schema_overrides, strict=strict + ) + + return pydf + + +def _establish_dataclass_or_model_schema( + first_element: Any, + schema: SchemaDefinition | None, + schema_overrides: SchemaDict | None, + model_fields: list[str] | None, +) -> tuple[bool, list[str], SchemaDict, SchemaDict]: + """Shared utility code for establishing dataclasses/pydantic model cols/schema.""" + from dataclasses import asdict + + unpack_nested = False + if schema: + column_names, schema_overrides = _unpack_schema( + schema, schema_overrides=schema_overrides + ) + overrides = {col: schema_overrides.get(col, Unknown) for col in column_names} + else: + column_names = [] + overrides = { + col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown) + for col, tp in try_get_type_hints(first_element.__class__).items() + if ((col in model_fields) if model_fields else (col != "__slots__")) + } + if schema_overrides: + overrides.update(schema_overrides) + elif not model_fields: + dc_fields = set(asdict(first_element)) + schema_overrides = overrides = { + nm: tp for nm, tp in overrides.items() if nm in dc_fields + } + else: + schema_overrides = overrides + + for col, tp in overrides.items(): + if tp in (Categorical, Enum): + overrides[col] = String + elif not unpack_nested and (tp.base_type() in (Unknown, Struct)): + unpack_nested = contains_nested( + getattr(first_element, col, None), + is_pydantic_model if model_fields else dataclasses.is_dataclass, # type: ignore[arg-type] + ) + + if model_fields and len(model_fields) == len(overrides): + overrides = dict(zip(model_fields, overrides.values())) + + return unpack_nested, column_names, schema_overrides, overrides + + +def _include_unknowns( + schema: SchemaDict, cols: Sequence[str] +) -> MutableMapping[str, PolarsDataType]: + """Complete partial schema dict by including Unknown type.""" + return { + col: ( + schema.get(col, Unknown) or Unknown # type: ignore[truthy-bool] + ) + for col in cols + } + + +def iterable_to_pydf( + data: Iterable[Any], + schema: SchemaDefinition | None = None, + *, + schema_overrides: SchemaDict | None = None, + strict: bool = True, + orient: Orientation | None = None, + chunk_size: int | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, +) -> PyDataFrame: + """Construct a PyDataFrame from an iterable/generator.""" + original_schema = schema + column_names: list[str] = [] + dtypes_by_idx: dict[int, PolarsDataType] = {} + if schema is not None: + column_names, schema_overrides = _unpack_schema( + schema, schema_overrides=schema_overrides + ) + elif schema_overrides: + _, schema_overrides = _unpack_schema(schema, schema_overrides=schema_overrides) + + if not isinstance(data, Generator): + data = iter(data) + + if orient == "col": + if column_names and schema_overrides: + dtypes_by_idx = { + idx: schema_overrides.get(col, Unknown) + for idx, col in enumerate(column_names) + } + + return pl.DataFrame( + { + (column_names[idx] if column_names else f"column_{idx}"): pl.Series( + coldata, + dtype=dtypes_by_idx.get(idx), + strict=strict, + ) + for idx, coldata in enumerate(data) + }, + )._df + + def to_frame_chunk(values: list[Any], schema: SchemaDefinition | None) -> DataFrame: + return pl.DataFrame( + data=values, + schema=schema, + strict=strict, + orient="row", + infer_schema_length=infer_schema_length, + ) + + n_chunks = 0 + n_chunk_elems = 1_000_000 + + if chunk_size: + adaptive_chunk_size = chunk_size + elif column_names: + adaptive_chunk_size = n_chunk_elems // len(column_names) + else: + adaptive_chunk_size = None + + df: DataFrame = None # type: ignore[assignment] + chunk_size = max( + (infer_schema_length or 0), + (adaptive_chunk_size or 1000), + ) + while True: + values = list(islice(data, chunk_size)) + if not values: + break + frame_chunk = to_frame_chunk(values, original_schema) + if df is None: + df = frame_chunk + if not original_schema: + original_schema = list(df.schema.items()) + if chunk_size != adaptive_chunk_size: + if (n_columns := len(df.columns)) > 0: + chunk_size = adaptive_chunk_size = n_chunk_elems // n_columns + else: + df.vstack(frame_chunk, in_place=True) + n_chunks += 1 + + if df is None: + df = to_frame_chunk([], original_schema) + + if n_chunks > 0: + df = df.rechunk() + + return df._df + + +def pandas_to_pydf( + data: pd.DataFrame, + schema: SchemaDefinition | None = None, + *, + schema_overrides: SchemaDict | None = None, + strict: bool = True, + rechunk: bool = True, + nan_to_null: bool = True, + include_index: bool = False, +) -> PyDataFrame: + """Construct a PyDataFrame from a pandas DataFrame.""" + arrow_dict = {} + length = data.shape[0] + + if include_index and not _pandas_has_default_index(data): + for idxcol in data.index.names: + arrow_dict[str(idxcol)] = plc.pandas_series_to_arrow( + data.index.get_level_values(idxcol), + nan_to_null=nan_to_null, + length=length, + ) + + for col in data.columns: + arrow_dict[str(col)] = plc.pandas_series_to_arrow( + data[col], nan_to_null=nan_to_null, length=length + ) + + arrow_table = pa.table(arrow_dict) + return arrow_to_pydf( + arrow_table, + schema=schema, + schema_overrides=schema_overrides, + strict=strict, + rechunk=rechunk, + ) + + +def _pandas_has_default_index(df: pd.DataFrame) -> bool: + """Identify if the pandas frame only has a default (or equivalent) index.""" + from pandas.core.indexes.range import RangeIndex + + index_cols = df.index.names + + if len(index_cols) > 1 or index_cols not in ([None], [""]): + # not default: more than one index, or index is named + return False + elif df.index.equals(RangeIndex(start=0, stop=len(df), step=1)): + # is default: simple range index + return True + else: + # finally, is the index _equivalent_ to a default unnamed + # integer index with frame data that was previously sorted + return ( + str(df.index.dtype).startswith("int") + and (df.index.sort_values() == np.arange(len(df))).all() + ) + + +def arrow_to_pydf( + data: pa.Table | pa.RecordBatch, + schema: SchemaDefinition | None = None, + *, + schema_overrides: SchemaDict | None = None, + strict: bool = True, + rechunk: bool = True, +) -> PyDataFrame: + """Construct a PyDataFrame from an Arrow Table or RecordBatch.""" + original_schema = schema + data_column_names = data.schema.names + column_names, schema_overrides = _unpack_schema( + (schema or data_column_names), schema_overrides=schema_overrides + ) + try: + if column_names != data_column_names: + if isinstance(data, pa.RecordBatch): + data = pa.Table.from_batches([data]) + data = data.rename_columns(column_names) + except pa.lib.ArrowInvalid as e: + msg = "dimensions of columns arg must match data dimensions" + raise ValueError(msg) from e + + data_dict = {} + # dictionaries cannot be built in different batches (categorical does not allow + # that) so we rechunk them and create them separately. + dictionary_cols = {} + # struct columns don't work properly if they contain multiple chunks. + struct_cols = {} + names = [] + for i, column in enumerate(data): + # extract the name before casting + name = f"column_{i}" if column._name is None else column._name + names.append(name) + + column = plc.coerce_arrow(column) + if pa.types.is_dictionary(column.type): + ps = plc.arrow_to_pyseries(name, column, rechunk=rechunk) + dictionary_cols[i] = wrap_s(ps) + elif isinstance(column.type, pa.StructType) and column.num_chunks > 1: + ps = plc.arrow_to_pyseries(name, column, rechunk=rechunk) + struct_cols[i] = wrap_s(ps) + else: + data_dict[name] = column + + if len(data_dict) > 0: + tbl = pa.table(data_dict) + + # path for table without rows that keeps datatype + if tbl.shape[0] == 0: + pydf = pl.DataFrame( + [pl.Series(name, c) for (name, c) in zip(tbl.column_names, tbl.columns)] + )._df + else: + pydf = PyDataFrame.from_arrow_record_batches(tbl.to_batches()) + else: + pydf = pl.DataFrame([])._df + if rechunk: + pydf = pydf.rechunk() + + reset_order = False + if len(dictionary_cols) > 0: + df = wrap_df(pydf) + df = df.with_columns([F.lit(s).alias(s.name) for s in dictionary_cols.values()]) + reset_order = True + + if len(struct_cols) > 0: + df = wrap_df(pydf) + df = df.with_columns([F.lit(s).alias(s.name) for s in struct_cols.values()]) + reset_order = True + + if reset_order: + df = df[names] + pydf = df._df + + if column_names != original_schema and (schema_overrides or original_schema): + pydf = _post_apply_columns( + pydf, + original_schema, + schema_overrides=schema_overrides, + strict=strict, + ) + elif schema_overrides: + for col, dtype in zip(pydf.columns(), pydf.dtypes()): + override_dtype = schema_overrides.get(col) + if override_dtype is not None and dtype != override_dtype: + pydf = _post_apply_columns( + pydf, original_schema, schema_overrides=schema_overrides + ) + break return pydf @@ -1442,6 +1119,7 @@ def numpy_to_pydf( *, schema_overrides: SchemaDict | None = None, orient: Orientation | None = None, + strict: bool = True, nan_to_null: bool = False, ) -> PyDataFrame: """Construct a PyDataFrame from a numpy ndarray (including structured ndarrays).""" @@ -1519,6 +1197,7 @@ def numpy_to_pydf( name=series_name, values=data[record_name], dtype=schema_overrides.get(record_name), + strict=strict, nan_to_null=nan_to_null, )._s for series_name, record_name in zip(column_names, record_names) @@ -1532,6 +1211,7 @@ def numpy_to_pydf( name=column_names[0], values=data, dtype=schema_overrides.get(column_names[0]), + strict=strict, nan_to_null=nan_to_null, )._s ] @@ -1546,6 +1226,7 @@ def numpy_to_pydf( else data[:, i] ), dtype=schema_overrides.get(column_names[i]), + strict=strict, nan_to_null=nan_to_null, )._s for i in range(n_columns) @@ -1558,6 +1239,7 @@ def numpy_to_pydf( data if two_d and n_columns == 1 and shape[1] > 1 else data[i] ), dtype=schema_overrides.get(column_names[i]), + strict=strict, nan_to_null=nan_to_null, )._s for i in range(n_columns) @@ -1567,97 +1249,12 @@ def numpy_to_pydf( return PyDataFrame(data_series) -def arrow_to_pydf( - data: pa.Table, - schema: SchemaDefinition | None = None, - *, - schema_overrides: SchemaDict | None = None, - rechunk: bool = True, -) -> PyDataFrame: - """Construct a PyDataFrame from an Arrow Table.""" - original_schema = schema - column_names, schema_overrides = _unpack_schema( - (schema or data.column_names), schema_overrides=schema_overrides - ) - try: - if column_names != data.column_names: - data = data.rename_columns(column_names) - except pa.lib.ArrowInvalid as e: - msg = "dimensions of columns arg must match data dimensions" - raise ValueError(msg) from e - - data_dict = {} - # dictionaries cannot be built in different batches (categorical does not allow - # that) so we rechunk them and create them separately. - dictionary_cols = {} - # struct columns don't work properly if they contain multiple chunks. - struct_cols = {} - names = [] - for i, column in enumerate(data): - # extract the name before casting - name = f"column_{i}" if column._name is None else column._name - names.append(name) - - column = coerce_arrow(column) - if pa.types.is_dictionary(column.type): - ps = arrow_to_pyseries(name, column, rechunk=rechunk) - dictionary_cols[i] = wrap_s(ps) - elif isinstance(column.type, pa.StructType) and column.num_chunks > 1: - ps = arrow_to_pyseries(name, column, rechunk=rechunk) - struct_cols[i] = wrap_s(ps) - else: - data_dict[name] = column - - if len(data_dict) > 0: - tbl = pa.table(data_dict) - - # path for table without rows that keeps datatype - if tbl.shape[0] == 0: - pydf = pl.DataFrame( - [pl.Series(name, c) for (name, c) in zip(tbl.column_names, tbl.columns)] - )._df - else: - pydf = PyDataFrame.from_arrow_record_batches(tbl.to_batches()) - else: - pydf = pl.DataFrame([])._df - if rechunk: - pydf = pydf.rechunk() - - reset_order = False - if len(dictionary_cols) > 0: - df = wrap_df(pydf) - df = df.with_columns([F.lit(s).alias(s.name) for s in dictionary_cols.values()]) - reset_order = True - - if len(struct_cols) > 0: - df = wrap_df(pydf) - df = df.with_columns([F.lit(s).alias(s.name) for s in struct_cols.values()]) - reset_order = True - - if reset_order: - df = df[names] - pydf = df._df - - if column_names != original_schema and (schema_overrides or original_schema): - pydf = _post_apply_columns( - pydf, original_schema, schema_overrides=schema_overrides - ) - elif schema_overrides: - for col, dtype in zip(pydf.columns(), pydf.dtypes()): - override_dtype = schema_overrides.get(col) - if override_dtype is not None and dtype != override_dtype: - pydf = _post_apply_columns( - pydf, original_schema, schema_overrides=schema_overrides - ) - break - - return pydf - - def series_to_pydf( data: Series, schema: SchemaDefinition | None = None, schema_overrides: SchemaDict | None = None, + *, + strict: bool = True, ) -> PyDataFrame: """Construct a PyDataFrame from a Polars Series.""" if schema is None and schema_overrides is None: @@ -1671,16 +1268,18 @@ def series_to_pydf( if schema_overrides: new_dtype = next(iter(schema_overrides.values())) if new_dtype != data.dtype: - data_series[0] = data_series[0].cast(new_dtype, strict=True) + data_series[0] = data_series[0].cast(new_dtype, strict=strict) data_series = _handle_columns_arg(data_series, columns=column_names) return PyDataFrame(data_series) -def frame_to_pydf( +def dataframe_to_pydf( data: DataFrame, schema: SchemaDefinition | None = None, + *, schema_overrides: SchemaDict | None = None, + strict: bool = True, ) -> PyDataFrame: """Construct a PyDataFrame from an existing Polars DataFrame.""" if schema is None and schema_overrides is None: @@ -1694,210 +1293,7 @@ def frame_to_pydf( existing_schema = data.schema for name, new_dtype in schema_overrides.items(): if new_dtype != existing_schema[name]: - data_series[name] = data_series[name].cast(new_dtype, strict=True) + data_series[name] = data_series[name].cast(new_dtype, strict=strict) series_cols = _handle_columns_arg(list(data_series.values()), columns=column_names) return PyDataFrame(series_cols) - - -def iterable_to_pydf( - data: Iterable[Any], - schema: SchemaDefinition | None = None, - schema_overrides: SchemaDict | None = None, - orient: Orientation | None = None, - chunk_size: int | None = None, - infer_schema_length: int | None = N_INFER_DEFAULT, -) -> PyDataFrame: - """Construct a PyDataFrame from an iterable/generator.""" - original_schema = schema - column_names: list[str] = [] - dtypes_by_idx: dict[int, PolarsDataType] = {} - if schema is not None: - column_names, schema_overrides = _unpack_schema( - schema, schema_overrides=schema_overrides - ) - elif schema_overrides: - _, schema_overrides = _unpack_schema(schema, schema_overrides=schema_overrides) - - if not isinstance(data, Generator): - data = iter(data) - - if orient == "col": - if column_names and schema_overrides: - dtypes_by_idx = { - idx: schema_overrides.get(col, Unknown) - for idx, col in enumerate(column_names) - } - - return pl.DataFrame( - { - (column_names[idx] if column_names else f"column_{idx}"): pl.Series( - coldata, dtype=dtypes_by_idx.get(idx) - ) - for idx, coldata in enumerate(data) - } - )._df - - def to_frame_chunk(values: list[Any], schema: SchemaDefinition | None) -> DataFrame: - return pl.DataFrame( - data=values, - schema=schema, - orient="row", - infer_schema_length=infer_schema_length, - ) - - n_chunks = 0 - n_chunk_elems = 1_000_000 - - if chunk_size: - adaptive_chunk_size = chunk_size - elif column_names: - adaptive_chunk_size = n_chunk_elems // len(column_names) - else: - adaptive_chunk_size = None - - df: DataFrame = None # type: ignore[assignment] - chunk_size = max( - (infer_schema_length or 0), - (adaptive_chunk_size or 1000), - ) - while True: - values = list(islice(data, chunk_size)) - if not values: - break - frame_chunk = to_frame_chunk(values, original_schema) - if df is None: - df = frame_chunk - if not original_schema: - original_schema = list(df.schema.items()) - if chunk_size != adaptive_chunk_size: - if (n_columns := len(df.columns)) > 0: - chunk_size = adaptive_chunk_size = n_chunk_elems // n_columns - else: - df.vstack(frame_chunk, in_place=True) - n_chunks += 1 - - if df is None: - df = to_frame_chunk([], original_schema) - - return (df.rechunk() if n_chunks > 0 else df)._df - - -def pandas_has_default_index(df: pd.DataFrame) -> bool: - """Identify if the pandas frame only has a default (or equivalent) index.""" - from pandas.core.indexes.range import RangeIndex - - index_cols = df.index.names - - if len(index_cols) > 1 or index_cols not in ([None], [""]): - # not default: more than one index, or index is named - return False - elif df.index.equals(RangeIndex(start=0, stop=len(df), step=1)): - # is default: simple range index - return True - else: - # finally, is the index _equivalent_ to a default unnamed - # integer index with frame data that was previously sorted - return ( - str(df.index.dtype).startswith("int") - and (df.index.sort_values() == np.arange(len(df))).all() - ) - - -def pandas_to_pydf( - data: pd.DataFrame, - schema: SchemaDefinition | None = None, - *, - schema_overrides: SchemaDict | None = None, - rechunk: bool = True, - nan_to_null: bool = True, - include_index: bool = False, -) -> PyDataFrame: - """Construct a PyDataFrame from a pandas DataFrame.""" - arrow_dict = {} - length = data.shape[0] - - if include_index and not pandas_has_default_index(data): - for idxcol in data.index.names: - arrow_dict[str(idxcol)] = _pandas_series_to_arrow( - data.index.get_level_values(idxcol), - nan_to_null=nan_to_null, - length=length, - ) - - for col in data.columns: - arrow_dict[str(col)] = _pandas_series_to_arrow( - data[col], nan_to_null=nan_to_null, length=length - ) - - arrow_table = pa.table(arrow_dict) - return arrow_to_pydf( - arrow_table, schema=schema, schema_overrides=schema_overrides, rechunk=rechunk - ) - - -def coerce_arrow(array: pa.Array) -> pa.Array: - import pyarrow.compute as pc - - if hasattr(array, "num_chunks") and array.num_chunks > 1: - # small integer keys can often not be combined, so let's already cast - # to the uint32 used by polars - if pa.types.is_dictionary(array.type) and ( - pa.types.is_int8(array.type.index_type) - or pa.types.is_uint8(array.type.index_type) - or pa.types.is_int16(array.type.index_type) - or pa.types.is_uint16(array.type.index_type) - or pa.types.is_int32(array.type.index_type) - ): - array = pc.cast( - array, pa.dictionary(pa.uint32(), pa.large_string()) - ).combine_chunks() - return array - - -def numpy_to_idxs(idxs: np.ndarray[Any, Any], size: int) -> pl.Series: - # Unsigned or signed Numpy array (ordered from fastest to slowest). - # - np.uint32 (polars) or np.uint64 (polars_u64_idx) numpy array - # indexes. - # - Other unsigned numpy array indexes are converted to pl.UInt32 - # (polars) or pl.UInt64 (polars_u64_idx). - # - Signed numpy array indexes are converted pl.UInt32 (polars) or - # pl.UInt64 (polars_u64_idx) after negative indexes are converted - # to absolute indexes. - if idxs.ndim != 1: - msg = "only 1D numpy array is supported as index" - raise ValueError(msg) - - idx_type = get_index_type() - - if len(idxs) == 0: - return pl.Series("", [], dtype=idx_type) - - # Numpy array with signed or unsigned integers. - if idxs.dtype.kind not in ("i", "u"): - msg = "unsupported idxs datatype" - raise NotImplementedError(msg) - - if idx_type == UInt32: - if idxs.dtype in {np.int64, np.uint64} and idxs.max() >= 2**32: - msg = "index positions should be smaller than 2^32" - raise ValueError(msg) - if idxs.dtype == np.int64 and idxs.min() < -(2**32): - msg = "index positions should be bigger than -2^32 + 1" - raise ValueError(msg) - - if idxs.dtype.kind == "i" and idxs.min() < 0: - if idx_type == UInt32: - if idxs.dtype in (np.int8, np.int16): - idxs = idxs.astype(np.int32) - else: - if idxs.dtype in (np.int8, np.int16, np.int32): - idxs = idxs.astype(np.int64) - - # Update negative indexes to absolute indexes. - idxs = np.where(idxs < 0, size + idxs, idxs) - - # numpy conversion is much faster - idxs = idxs.astype(np.uint32) if idx_type == UInt32 else idxs.astype(np.uint64) - - return pl.Series("", idxs, dtype=idx_type) diff --git a/py-polars/polars/_utils/construction/other.py b/py-polars/polars/_utils/construction/other.py new file mode 100644 index 0000000000000..c5a2a06bb7b53 --- /dev/null +++ b/py-polars/polars/_utils/construction/other.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import polars._reexport as pl +from polars._utils.construction.utils import get_first_non_none +from polars.datatypes import UInt32 +from polars.dependencies import numpy as np +from polars.dependencies import pyarrow as pa +from polars.meta import get_index_type + +if TYPE_CHECKING: + from polars import Series + from polars.dependencies import pandas as pd + + +def numpy_to_idxs(idxs: np.ndarray[Any, Any], size: int) -> Series: + # Unsigned or signed Numpy array (ordered from fastest to slowest). + # - np.uint32 (polars) or np.uint64 (polars_u64_idx) numpy array + # indexes. + # - Other unsigned numpy array indexes are converted to pl.UInt32 + # (polars) or pl.UInt64 (polars_u64_idx). + # - Signed numpy array indexes are converted pl.UInt32 (polars) or + # pl.UInt64 (polars_u64_idx) after negative indexes are converted + # to absolute indexes. + if idxs.ndim != 1: + msg = "only 1D numpy array is supported as index" + raise ValueError(msg) + + idx_type = get_index_type() + + if len(idxs) == 0: + return pl.Series("", [], dtype=idx_type) + + # Numpy array with signed or unsigned integers. + if idxs.dtype.kind not in ("i", "u"): + msg = "unsupported idxs datatype" + raise NotImplementedError(msg) + + if idx_type == UInt32: + if idxs.dtype in {np.int64, np.uint64} and idxs.max() >= 2**32: + msg = "index positions should be smaller than 2^32" + raise ValueError(msg) + if idxs.dtype == np.int64 and idxs.min() < -(2**32): + msg = "index positions should be bigger than -2^32 + 1" + raise ValueError(msg) + + if idxs.dtype.kind == "i" and idxs.min() < 0: + if idx_type == UInt32: + if idxs.dtype in (np.int8, np.int16): + idxs = idxs.astype(np.int32) + else: + if idxs.dtype in (np.int8, np.int16, np.int32): + idxs = idxs.astype(np.int64) + + # Update negative indexes to absolute indexes. + idxs = np.where(idxs < 0, size + idxs, idxs) + + # numpy conversion is much faster + idxs = idxs.astype(np.uint32) if idx_type == UInt32 else idxs.astype(np.uint64) + + return pl.Series("", idxs, dtype=idx_type) + + +def pandas_series_to_arrow( + values: pd.Series[Any] | pd.Index[Any], + *, + length: int | None = None, + nan_to_null: bool = True, +) -> pa.Array: + """ + Convert a pandas Series to an Arrow Array. + + Parameters + ---------- + values : :class:`pandas.Series` or :class:`pandas.Index`. + Series to convert to arrow + nan_to_null : bool, default = True + Interpret `NaN` as missing values. + length : int, optional + in case all values are null, create a null array of this length. + if unset, length is inferred from values. + + Returns + ------- + :class:`pyarrow.Array` + """ + dtype = getattr(values, "dtype", None) + if dtype == "object": + first_non_none = get_first_non_none(values.values) # type: ignore[arg-type] + if isinstance(first_non_none, str): + return pa.array(values, pa.large_utf8(), from_pandas=nan_to_null) + elif first_non_none is None: + return pa.nulls(length or len(values), pa.large_utf8()) + return pa.array(values, from_pandas=nan_to_null) + elif dtype: + return pa.array(values, from_pandas=nan_to_null) + else: + # Pandas Series is actually a Pandas DataFrame when the original DataFrame + # contains duplicated columns and a duplicated column is requested with df["a"]. + msg = "duplicate column names found: " + raise ValueError( + msg, + f"{values.columns.tolist()!s}", # type: ignore[union-attr] + ) + + +def coerce_arrow(array: pa.Array) -> pa.Array: + """...""" + import pyarrow.compute as pc + + if hasattr(array, "num_chunks") and array.num_chunks > 1: + # small integer keys can often not be combined, so let's already cast + # to the uint32 used by polars + if pa.types.is_dictionary(array.type) and ( + pa.types.is_int8(array.type.index_type) + or pa.types.is_uint8(array.type.index_type) + or pa.types.is_int16(array.type.index_type) + or pa.types.is_uint16(array.type.index_type) + or pa.types.is_int32(array.type.index_type) + ): + array = pc.cast( + array, pa.dictionary(pa.uint32(), pa.large_string()) + ).combine_chunks() + return array diff --git a/py-polars/polars/_utils/construction/series.py b/py-polars/polars/_utils/construction/series.py new file mode 100644 index 0000000000000..cdcb03d336a59 --- /dev/null +++ b/py-polars/polars/_utils/construction/series.py @@ -0,0 +1,521 @@ +from __future__ import annotations + +import contextlib +import warnings +from datetime import date, datetime, time, timedelta +from decimal import Decimal as PyDecimal +from itertools import islice +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Iterator, + Sequence, +) + +import polars._reexport as pl +import polars._utils.construction as plc +from polars import functions as F +from polars._utils.construction.utils import ( + get_first_non_none, + is_namedtuple, + is_pydantic_model, +) +from polars._utils.various import ( + find_stacklevel, + range_to_series, +) +from polars._utils.wrap import wrap_s +from polars.datatypes import ( + INTEGER_DTYPES, + TEMPORAL_DTYPES, + Boolean, + Categorical, + Date, + Datetime, + Duration, + Enum, + List, + Null, + Object, + Struct, + Time, + Unknown, + dtype_to_py_type, + is_polars_dtype, + numpy_char_code_to_dtype, + py_type_to_dtype, +) +from polars.datatypes.constructor import ( + numpy_type_to_constructor, + numpy_values_and_dtype, + polars_type_to_constructor, + py_type_to_constructor, +) +from polars.dependencies import ( + _check_for_numpy, + dataclasses, +) +from polars.dependencies import numpy as np +from polars.dependencies import pandas as pd +from polars.dependencies import pyarrow as pa +from polars.exceptions import TimeZoneAwareConstructorWarning + +with contextlib.suppress(ImportError): # Module not available when building docs + from polars.polars import PyDataFrame, PySeries + +if TYPE_CHECKING: + from polars import DataFrame, Series + from polars.dependencies import pandas as pd + from polars.type_aliases import PolarsDataType + + +def sequence_to_pyseries( + name: str, + values: Sequence[Any], + dtype: PolarsDataType | None = None, + *, + strict: bool = True, + nan_to_null: bool = False, +) -> PySeries: + """Construct a PySeries from a sequence.""" + python_dtype: type | None = None + + if isinstance(values, range): + return range_to_series(name, values, dtype=dtype)._s + + # empty sequence + if not values and dtype is None: + # if dtype for empty sequence could be guessed + # (e.g comparisons between self and other), default to Null + dtype = Null + + # lists defer to subsequent handling; identify nested type + elif dtype == List: + python_dtype = list + + # infer temporal type handling + py_temporal_types = {date, datetime, timedelta, time} + pl_temporal_types = {Date, Datetime, Duration, Time} + + value = get_first_non_none(values) + if value is not None: + if ( + dataclasses.is_dataclass(value) + or is_pydantic_model(value) + or is_namedtuple(value.__class__) + ) and dtype != Object: + return pl.DataFrame(values).to_struct(name)._s + elif isinstance(value, range) and dtype is None: + values = [range_to_series("", v) for v in values] + else: + # for temporal dtypes: + # * if the values are integer, we take the physical branch. + # * if the values are python types, take the temporal branch. + # * if the values are ISO-8601 strings, init then convert via strptime. + # * if the values are floats/other dtypes, this is an error. + if dtype in py_temporal_types and isinstance(value, int): + dtype = py_type_to_dtype(dtype) # construct from integer + elif ( + dtype in pl_temporal_types or type(dtype) in pl_temporal_types + ) and not isinstance(value, int): + python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type] + + # physical branch + # flat data + if ( + dtype is not None + and dtype not in (List, Struct, Unknown) + and is_polars_dtype(dtype) + and (python_dtype is None) + ): + constructor = polars_type_to_constructor(dtype) + pyseries = _construct_series_with_fallbacks( + constructor, name, values, dtype, strict=strict + ) + if dtype in (Date, Datetime, Duration, Time, Categorical, Boolean, Enum): + if pyseries.dtype() != dtype: + pyseries = pyseries.cast(dtype, strict=strict) + return pyseries + + elif dtype == Struct: + struct_schema = dtype.to_schema() if isinstance(dtype, Struct) else None + empty = {} # type: ignore[var-annotated] + return plc.sequence_to_pydf( + data=[(empty if v is None else v) for v in values], + schema=struct_schema, + orient="row", + ).to_struct(name) + else: + if python_dtype is None: + if value is None: + constructor = polars_type_to_constructor(Null) + return constructor(name, values, strict) + + # generic default dtype + python_dtype = type(value) + + # temporal branch + if python_dtype in py_temporal_types: + if dtype is None: + dtype = py_type_to_dtype(python_dtype) # construct from integer + elif dtype in py_temporal_types: + dtype = py_type_to_dtype(dtype) + + values_dtype = ( + None + if value is None + else py_type_to_dtype(type(value), raise_unmatched=False) + ) + if values_dtype is not None and values_dtype.is_float(): + msg = f"'float' object cannot be interpreted as a {python_dtype.__name__!r}" + raise TypeError( + # we do not accept float values as temporal; if this is + # required, the caller should explicitly cast to int first. + msg + ) + + # We use the AnyValue builder to create the datetime array + # We store the values internally as UTC and set the timezone + py_series = PySeries.new_from_any_values(name, values, strict) + + time_unit = getattr(dtype, "time_unit", None) + time_zone = getattr(dtype, "time_zone", None) + + if time_unit is None or values_dtype == Date: + s = wrap_s(py_series) + else: + s = wrap_s(py_series).dt.cast_time_unit(time_unit) + + if (values_dtype == Date) & (dtype == Datetime): + return ( + s.cast(Datetime(time_unit or "us")) + .dt.replace_time_zone(time_zone) + ._s + ) + + if (dtype == Datetime) and ( + value.tzinfo is not None or time_zone is not None + ): + values_tz = str(value.tzinfo) if value.tzinfo is not None else None + dtype_tz = dtype.time_zone # type: ignore[union-attr] + if values_tz is not None and ( + dtype_tz is not None and dtype_tz != "UTC" + ): + msg = ( + "time-zone-aware datetimes are converted to UTC" + "\n\nPlease either drop the time zone from the dtype, or set it to 'UTC'." + " To convert to a different time zone, please use `.dt.convert_time_zone`." + ) + raise ValueError(msg) + if values_tz != "UTC" and dtype_tz is None: + warnings.warn( + "Constructing a Series with time-zone-aware " + "datetimes results in a Series with UTC time zone. " + "To silence this warning, you can filter " + "warnings of class TimeZoneAwareConstructorWarning, or " + "set 'UTC' as the time zone of your datatype.", + TimeZoneAwareConstructorWarning, + stacklevel=find_stacklevel(), + ) + return s.dt.replace_time_zone(dtype_tz or "UTC")._s + return s._s + + elif ( + _check_for_numpy(value) + and isinstance(value, np.ndarray) + and len(value.shape) == 1 + ): + n_elems = len(value) + if all(len(v) == n_elems for v in values): + # can take (much) faster path if all lists are the same length + return numpy_to_pyseries( + name, + np.vstack(values), + strict=strict, + nan_to_null=nan_to_null, + ) + else: + return PySeries.new_series_list( + name, + [ + numpy_to_pyseries("", v, strict=strict, nan_to_null=nan_to_null) + for v in values + ], + strict, + ) + + elif python_dtype in (list, tuple): + if dtype is None: + return PySeries.new_from_any_values(name, values, strict=strict) + elif dtype == Object: + return PySeries.new_object(name, values, strict) + else: + if (inner_dtype := getattr(dtype, "inner", None)) is not None: + pyseries_list = [ + None + if value is None + else sequence_to_pyseries( + "", + value, + inner_dtype, + strict=strict, + nan_to_null=nan_to_null, + ) + for value in values + ] + pyseries = PySeries.new_series_list(name, pyseries_list, strict) + else: + pyseries = PySeries.new_from_any_values_and_dtype( + name, values, dtype, strict=strict + ) + if dtype != pyseries.dtype(): + pyseries = pyseries.cast(dtype, strict=False) + return pyseries + + elif python_dtype == pl.Series: + return PySeries.new_series_list( + name, [v._s if v is not None else None for v in values], strict + ) + + elif python_dtype == PySeries: + return PySeries.new_series_list(name, values, strict) + else: + constructor = py_type_to_constructor(python_dtype) + if constructor == PySeries.new_object: + try: + srs = PySeries.new_from_any_values(name, values, strict) + if _check_for_numpy(python_dtype, check_type=False) and isinstance( + np.bool_(True), np.generic + ): + dtype = numpy_char_code_to_dtype(np.dtype(python_dtype).char) + return srs.cast(dtype, strict=strict) + else: + return srs + + except RuntimeError: + return PySeries.new_from_any_values(name, values, strict=strict) + + return _construct_series_with_fallbacks( + constructor, name, values, dtype, strict=strict + ) + + +def _construct_series_with_fallbacks( + constructor: Callable[[str, Sequence[Any], bool], PySeries], + name: str, + values: Sequence[Any], + target_dtype: PolarsDataType | None, + *, + strict: bool, +) -> PySeries: + """Construct Series, with fallbacks for basic type mismatch (eg: bool/int).""" + while True: + try: + return constructor(name, values, strict) + except TypeError as exc: + str_exc = str(exc) + + # from x to float + # error message can be: + # - integers: "'float' object cannot be interpreted as an integer" + if "'float'" in str_exc and ( + # we do not accept float values as int/temporal, as it causes silent + # information loss; the caller should explicitly cast in this case. + target_dtype not in (INTEGER_DTYPES | TEMPORAL_DTYPES) + ): + constructor = py_type_to_constructor(float) + + # from x to string + # error message can be: + # - integers: "'str' object cannot be interpreted as an integer" + # - floats: "must be real number, not str" + elif "'str'" in str_exc or str_exc == "must be real number, not str": + constructor = py_type_to_constructor(str) + + # from x to int + # error message can be: + # - bools: "'int' object cannot be converted to 'PyBool'" + elif str_exc == "'int' object cannot be converted to 'PyBool'": + constructor = py_type_to_constructor(int) + + elif "decimal.Decimal" in str_exc: + constructor = py_type_to_constructor(PyDecimal) + else: + raise + + +def iterable_to_pyseries( + name: str, + values: Iterable[Any], + dtype: PolarsDataType | None = None, + *, + chunk_size: int = 1_000_000, + strict: bool = True, +) -> PySeries: + """Construct a PySeries from an iterable/generator.""" + if not isinstance(values, (Generator, Iterator)): + values = iter(values) + + def to_series_chunk(values: list[Any], dtype: PolarsDataType | None) -> Series: + return pl.Series( + name=name, + values=values, + dtype=dtype, + strict=strict, + ) + + n_chunks = 0 + series: Series = None # type: ignore[assignment] + while True: + slice_values = list(islice(values, chunk_size)) + if not slice_values: + break + schunk = to_series_chunk(slice_values, dtype) + if series is None: + series = schunk + dtype = series.dtype + else: + series.append(schunk) + n_chunks += 1 + + if series is None: + series = to_series_chunk([], dtype) + if n_chunks > 0: + series.rechunk(in_place=True) + + return series._s + + +def pandas_to_pyseries( + name: str, + values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex, + *, + nan_to_null: bool = True, +) -> PySeries: + """Construct a PySeries from a pandas Series or DatetimeIndex.""" + if not name and values.name is not None: + name = str(values.name) + return arrow_to_pyseries( + name, plc.pandas_series_to_arrow(values, nan_to_null=nan_to_null) + ) + + +def arrow_to_pyseries(name: str, values: pa.Array, *, rechunk: bool = True) -> PySeries: + """Construct a PySeries from an Arrow array.""" + array = plc.coerce_arrow(values) + + # special handling of empty categorical arrays + if ( + len(array) == 0 + and isinstance(array.type, pa.DictionaryType) + and array.type.value_type + in ( + pa.utf8(), + pa.large_utf8(), + ) + ): + pys = pl.Series(name, [], dtype=Categorical)._s + + elif not hasattr(array, "num_chunks"): + pys = PySeries.from_arrow(name, array) + else: + if array.num_chunks > 1: + # somehow going through ffi with a structarray + # returns the first chunk every time + if isinstance(array.type, pa.StructType): + pys = PySeries.from_arrow(name, array.combine_chunks()) + else: + it = array.iterchunks() + pys = PySeries.from_arrow(name, next(it)) + for a in it: + pys.append(PySeries.from_arrow(name, a)) + elif array.num_chunks == 0: + pys = PySeries.from_arrow(name, pa.nulls(0, type=array.type)) + else: + pys = PySeries.from_arrow(name, array.chunks[0]) + + if rechunk: + pys.rechunk(in_place=True) + + return pys + + +def numpy_to_pyseries( + name: str, + values: np.ndarray[Any, Any], + *, + strict: bool = True, + nan_to_null: bool = False, +) -> PySeries: + """Construct a PySeries from a numpy array.""" + values = np.ascontiguousarray(values) + + if values.ndim == 1: + values, dtype = numpy_values_and_dtype(values) + constructor = numpy_type_to_constructor(values, dtype) + return constructor( + name, values, nan_to_null if dtype in (np.float32, np.float64) else strict + ) + elif values.ndim == 2: + # Optimize by ingesting 1D and reshaping in Rust + original_shape = values.shape + values = values.reshape(-1) + py_s = numpy_to_pyseries( + name, + values, + strict=strict, + nan_to_null=nan_to_null, + ) + return ( + PyDataFrame([py_s]) + .lazy() + .select([F.col(name).reshape(original_shape)._pyexpr]) + .collect() + .select_at_idx(0) + ) + else: + return PySeries.new_object(name, values, strict) + + +def series_to_pyseries( + name: str | None, + values: Series, + *, + dtype: PolarsDataType | None = None, + strict: bool = True, +) -> PySeries: + """Construct a new PySeries from a Polars Series.""" + s = values.clone() + if dtype is not None and dtype != s.dtype: + s = s.cast(dtype, strict=strict) + if name is not None: + s = s.alias(name) + return s._s + + +def dataframe_to_pyseries( + name: str | None, + values: DataFrame, + *, + dtype: PolarsDataType | None = None, + strict: bool = True, +) -> PySeries: + """Construct a new PySeries from a Polars DataFrame.""" + if values.width > 1: + name = name or "" + s = values.to_struct(name) + elif values.width == 1: + s = values.to_series() + if name is not None: + s = s.alias(name) + else: + msg = "cannot initialize Series from DataFrame without any columns" + raise TypeError(msg) + + if dtype is not None and dtype != s.dtype: + s = s.cast(dtype, strict=strict) + + return s._s diff --git a/py-polars/polars/_utils/construction/utils.py b/py-polars/polars/_utils/construction/utils.py new file mode 100644 index 0000000000000..dbfc67933273d --- /dev/null +++ b/py-polars/polars/_utils/construction/utils.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import sys +from functools import lru_cache +from typing import Any, Callable, Sequence, get_type_hints + +from polars.dependencies import _check_for_pydantic, pydantic + + +def _get_annotations(obj: type) -> dict[str, Any]: + return getattr(obj, "__annotations__", {}) + + +if sys.version_info >= (3, 10): + + def try_get_type_hints(obj: type) -> dict[str, Any]: + try: + # often the same as obj.__annotations__, but handles forward references + # encoded as string literals, adds Optional[t] if a default value equal + # to None is set and recursively replaces 'Annotated[T, ...]' with 'T'. + return get_type_hints(obj) + except TypeError: + # fallback on edge-cases (eg: InitVar inference on python 3.10). + return _get_annotations(obj) + +else: + try_get_type_hints = _get_annotations + + +@lru_cache(64) +def is_namedtuple(cls: Any, *, annotated: bool = False) -> bool: + """Check whether given class derives from NamedTuple.""" + if all(hasattr(cls, attr) for attr in ("_fields", "_field_defaults", "_replace")): + if not isinstance(cls._fields, property): + if not annotated or len(cls.__annotations__) == len(cls._fields): + return all(isinstance(fld, str) for fld in cls._fields) + return False + + +def is_pydantic_model(value: Any) -> bool: + """Check whether value derives from a pydantic.BaseModel.""" + return _check_for_pydantic(value) and isinstance(value, pydantic.BaseModel) + + +def get_first_non_none(values: Sequence[Any | None]) -> Any: + """ + Return the first value from a sequence that isn't None. + + If sequence doesn't contain non-None values, return None. + """ + if values is not None: + return next((v for v in values if v is not None), None) + + +def nt_unpack(obj: Any) -> Any: + """Recursively unpack a nested NamedTuple.""" + if isinstance(obj, dict): + return {key: nt_unpack(value) for key, value in obj.items()} + elif isinstance(obj, list): + return [nt_unpack(value) for value in obj] + elif is_namedtuple(obj.__class__): + return {key: nt_unpack(value) for key, value in obj._asdict().items()} + elif isinstance(obj, tuple): + return tuple(nt_unpack(value) for value in obj) + else: + return obj + + +def contains_nested(value: Any, is_nested: Callable[[Any], bool]) -> bool: + """Determine if value contains (or is) nested structured data.""" + if is_nested(value): + return True + elif isinstance(value, dict): + return any(contains_nested(v, is_nested) for v in value.values()) + elif isinstance(value, (list, tuple)): + return any(contains_nested(v, is_nested) for v in value) + return False diff --git a/py-polars/polars/_utils/convert.py b/py-polars/polars/_utils/convert.py new file mode 100644 index 0000000000000..92ae98feb67a1 --- /dev/null +++ b/py-polars/polars/_utils/convert.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, timezone +from decimal import Context +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Callable, + NoReturn, + Sequence, + no_type_check, + overload, +) + +from polars.dependencies import _ZONEINFO_AVAILABLE, zoneinfo + +if TYPE_CHECKING: + from datetime import tzinfo + from decimal import Decimal + + from polars.type_aliases import TimeUnit + + +SECONDS_PER_DAY = 86_400 +SECONDS_PER_HOUR = 3_600 +NS_PER_SECOND = 1_000_000_000 +US_PER_SECOND = 1_000_000 +MS_PER_SECOND = 1_000 + +EPOCH_DATE = date(1970, 1, 1) +EPOCH = datetime(1970, 1, 1).replace(tzinfo=None) +EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) + + +@overload +def parse_as_duration_string(td: None) -> None: ... + + +@overload +def parse_as_duration_string(td: timedelta | str) -> str: ... + + +def parse_as_duration_string(td: timedelta | str | None) -> str | None: + """Parse duration input as a Polars duration string.""" + if td is None or isinstance(td, str): + return td + return _timedelta_to_duration_string(td) + + +def _timedelta_to_duration_string(td: timedelta) -> str: + """Convert a Python timedelta object to a Polars duration string.""" + # Positive duration + if td.days >= 0: + d = f"{td.days}d" if td.days != 0 else "" + s = f"{td.seconds}s" if td.seconds != 0 else "" + us = f"{td.microseconds}us" if td.microseconds != 0 else "" + # Negative, whole days + elif td.seconds == 0 and td.microseconds == 0: + return f"{td.days}d" + # Negative, other + else: + corrected_d = td.days + 1 + corrected_seconds = SECONDS_PER_DAY - (td.seconds + (td.microseconds > 0)) + d = f"{corrected_d}d" if corrected_d != 0 else "-" + s = f"{corrected_seconds}s" if corrected_seconds != 0 else "" + us = f"{10**6 - td.microseconds}us" if td.microseconds != 0 else "" + + return f"{d}{s}{us}" + + +def negate_duration_string(duration: str) -> str: + """Negate a Polars duration string.""" + if duration.startswith("-"): + return duration[1:] + else: + return f"-{duration}" + + +def date_to_int(d: date) -> int: + """Convert a Python time object to an integer.""" + return (d - EPOCH_DATE).days + + +def time_to_int(t: time) -> int: + """Convert a Python time object to an integer.""" + t = t.replace(tzinfo=timezone.utc) + seconds = t.hour * SECONDS_PER_HOUR + t.minute * 60 + t.second + microseconds = t.microsecond + return seconds * NS_PER_SECOND + microseconds * 1_000 + + +def datetime_to_int(dt: datetime, time_unit: TimeUnit) -> int: + """Convert a Python datetime object to an integer.""" + # Make sure to use UTC rather than system time zone + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + + td = dt - EPOCH_UTC + seconds = td.days * SECONDS_PER_DAY + td.seconds + microseconds = dt.microsecond + + if time_unit == "us": + return seconds * US_PER_SECOND + microseconds + elif time_unit == "ns": + return seconds * NS_PER_SECOND + microseconds * 1_000 + elif time_unit == "ms": + return seconds * MS_PER_SECOND + microseconds // 1_000 + else: + _raise_invalid_time_unit(time_unit) + + +def timedelta_to_int(td: timedelta, time_unit: TimeUnit) -> int: + """Convert a Python timedelta object to an integer.""" + seconds = td.days * SECONDS_PER_DAY + td.seconds + microseconds = td.microseconds + + if time_unit == "us": + return seconds * US_PER_SECOND + microseconds + elif time_unit == "ns": + return seconds * NS_PER_SECOND + microseconds * 1_000 + elif time_unit == "ms": + return seconds * MS_PER_SECOND + microseconds // 1_000 + else: + _raise_invalid_time_unit(time_unit) + + +@lru_cache(256) +def to_py_date(value: int | float) -> date: + """Convert an integer or float to a Python date object.""" + return EPOCH_DATE + timedelta(days=value) + + +def to_py_time(value: int) -> time: + """Convert an integer to a Python time object.""" + # Fast path for 00:00 + if value == 0: + return time() + + seconds, nanoseconds = divmod(value, NS_PER_SECOND) + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + return time( + hour=hours, minute=minutes, second=seconds, microsecond=nanoseconds // 1_000 + ) + + +def to_py_datetime( + value: int | float, + time_unit: TimeUnit, + time_zone: str | None = None, +) -> datetime: + """Convert an integer or float to a Python datetime object.""" + if time_unit == "us": + td = timedelta(microseconds=value) + elif time_unit == "ns": + td = timedelta(microseconds=value // 1_000) + elif time_unit == "ms": + td = timedelta(milliseconds=value) + else: + _raise_invalid_time_unit(time_unit) + + if time_zone is None: + return EPOCH + td + elif _ZONEINFO_AVAILABLE: + dt = EPOCH_UTC + td + return _localize_datetime(dt, time_zone) + else: + msg = "install polars[timezone] to handle datetimes with time zone information" + raise ImportError(msg) + + +def _localize_datetime(dt: datetime, time_zone: str) -> datetime: + # zone info installation should already be checked + try: + tz = string_to_zoneinfo(time_zone) + except zoneinfo.ZoneInfoNotFoundError: + # try fixed offset, which is not supported by ZoneInfo + tz = _parse_fixed_tz_offset(time_zone) + + return dt.astimezone(tz) + + +@no_type_check +@lru_cache(None) +def string_to_zoneinfo(key: str) -> Any: + """ + Convert a time zone string to a Python ZoneInfo object. + + This is a simple wrapper for the zoneinfo.ZoneInfo constructor. + The wrapper is useful because zoneinfo is not available on Python 3.8 + and the backports module may not be installed. + """ + return zoneinfo.ZoneInfo(key) + + +# cache here as we have a single tz per column +# and this function will be called on every conversion +@lru_cache(16) +def _parse_fixed_tz_offset(offset: str) -> tzinfo: + try: + # use fromisoformat to parse the offset + dt_offset = datetime.fromisoformat("2000-01-01T00:00:00" + offset) + + # alternatively, we parse the offset ourselves extracting hours and + # minutes, then we can construct: + # tzinfo=timezone(timedelta(hours=..., minutes=...)) + except ValueError: + msg = f"unexpected time zone offset: {offset!r}" + raise ValueError(msg) from None + + return dt_offset.tzinfo # type: ignore[return-value] + + +def to_py_timedelta(value: int | float, time_unit: TimeUnit) -> timedelta: + """Convert an integer or float to a Python timedelta object.""" + if time_unit == "us": + return timedelta(microseconds=value) + elif time_unit == "ns": + return timedelta(microseconds=value // 1_000) + elif time_unit == "ms": + return timedelta(milliseconds=value) + else: + _raise_invalid_time_unit(time_unit) + + +def to_py_decimal(sign: int, digits: Sequence[int], prec: int, scale: int) -> Decimal: + """Convert decimal components to a Python Decimal object.""" + return _create_decimal_with_prec(prec)((sign, digits, scale)) + + +@lru_cache(None) +def _create_decimal_with_prec( + precision: int, +) -> Callable[[tuple[int, Sequence[int], int]], Decimal]: + # pre-cache contexts so we don't have to spend time on recreating them every time + return Context(prec=precision).create_decimal + + +def _raise_invalid_time_unit(time_unit: Any) -> NoReturn: + msg = f"`time_unit` must be one of {{'ms', 'us', 'ns'}}, got {time_unit!r}" + raise ValueError(msg) diff --git a/py-polars/polars/utils/deprecation.py b/py-polars/polars/_utils/deprecation.py similarity index 98% rename from py-polars/polars/utils/deprecation.py rename to py-polars/polars/_utils/deprecation.py index a95d711ebc5f1..aca7df3a78502 100644 --- a/py-polars/polars/utils/deprecation.py +++ b/py-polars/polars/_utils/deprecation.py @@ -5,7 +5,7 @@ from functools import wraps from typing import TYPE_CHECKING, Callable, Sequence, TypeVar -from polars.utils.various import find_stacklevel +from polars._utils.various import find_stacklevel if TYPE_CHECKING: import sys @@ -85,8 +85,7 @@ def deprecate_parameter_as_positional( Use as follows:: @deprecate_parameter_as_positional("column", version="0.20.4") - def myfunc(new_name): - ... + def myfunc(new_name): ... """ def decorate(function: Callable[P, T]) -> Callable[P, T]: @@ -123,8 +122,7 @@ def deprecate_renamed_parameter( Use as follows:: @deprecate_renamed_parameter("old_name", "new_name", version="0.20.4") - def myfunc(new_name): - ... + def myfunc(new_name): ... """ def decorate(function: Callable[P, T]) -> Callable[P, T]: diff --git a/py-polars/polars/utils/_parse_expr_input.py b/py-polars/polars/_utils/parse_expr_input.py similarity index 98% rename from py-polars/polars/utils/_parse_expr_input.py rename to py-polars/polars/_utils/parse_expr_input.py index 05970f11b3674..1aa72d2b92262 100644 --- a/py-polars/polars/utils/_parse_expr_input.py +++ b/py-polars/polars/_utils/parse_expr_input.py @@ -5,8 +5,8 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.deprecation import issue_deprecation_warning from polars.exceptions import ComputeError -from polars.utils.deprecation import issue_deprecation_warning with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/utils/_polars_version.py b/py-polars/polars/_utils/polars_version.py similarity index 100% rename from py-polars/polars/utils/_polars_version.py rename to py-polars/polars/_utils/polars_version.py diff --git a/py-polars/polars/utils/_scan.py b/py-polars/polars/_utils/scan.py similarity index 100% rename from py-polars/polars/utils/_scan.py rename to py-polars/polars/_utils/scan.py diff --git a/py-polars/polars/_utils/udfs.py b/py-polars/polars/_utils/udfs.py new file mode 100644 index 0000000000000..b4382b09dabec --- /dev/null +++ b/py-polars/polars/_utils/udfs.py @@ -0,0 +1,981 @@ +"""Utilities related to user defined functions (such as those passed to `apply`).""" + +from __future__ import annotations + +import datetime +import dis +import inspect +import re +import sys +import warnings +from bisect import bisect_left +from collections import defaultdict +from dis import get_instructions +from inspect import signature +from itertools import count, zip_longest +from pathlib import Path +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + ClassVar, + Iterator, + Literal, + NamedTuple, + Union, +) + +from polars._utils.various import re_escape + +if TYPE_CHECKING: + from dis import Instruction + + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + + +class StackValue(NamedTuple): + operator: str + operator_arity: int + left_operand: str + right_operand: str + + +MapTarget: TypeAlias = Literal["expr", "frame", "series"] +StackEntry: TypeAlias = Union[str, StackValue] + +_MIN_PY311 = sys.version_info >= (3, 11) +_MIN_PY312 = _MIN_PY311 and sys.version_info >= (3, 12) + + +class OpNames: + BINARY: ClassVar[dict[str, str]] = { + "BINARY_ADD": "+", + "BINARY_AND": "&", + "BINARY_FLOOR_DIVIDE": "//", + "BINARY_LSHIFT": "<<", + "BINARY_RSHIFT": ">>", + "BINARY_MODULO": "%", + "BINARY_MULTIPLY": "*", + "BINARY_OR": "|", + "BINARY_POWER": "**", + "BINARY_SUBTRACT": "-", + "BINARY_TRUE_DIVIDE": "/", + "BINARY_XOR": "^", + } + CALL = frozenset({"CALL"} if _MIN_PY311 else {"CALL_FUNCTION", "CALL_METHOD"}) + CONTROL_FLOW: ClassVar[dict[str, str]] = ( + { + "POP_JUMP_FORWARD_IF_FALSE": "&", + "POP_JUMP_FORWARD_IF_TRUE": "|", + "JUMP_IF_FALSE_OR_POP": "&", + "JUMP_IF_TRUE_OR_POP": "|", + } + # note: 3.12 dropped POP_JUMP_FORWARD_IF_* opcodes + if _MIN_PY311 and not _MIN_PY312 + else { + "POP_JUMP_IF_FALSE": "&", + "POP_JUMP_IF_TRUE": "|", + "JUMP_IF_FALSE_OR_POP": "&", + "JUMP_IF_TRUE_OR_POP": "|", + } + ) + LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL")) + LOAD_ATTR = frozenset({"LOAD_METHOD", "LOAD_ATTR"}) + LOAD = LOAD_VALUES | LOAD_ATTR + SYNTHETIC: ClassVar[dict[str, int]] = { + "POLARS_EXPRESSION": 1, + } + UNARY: ClassVar[dict[str, str]] = { + "UNARY_NEGATIVE": "-", + "UNARY_POSITIVE": "+", + "UNARY_NOT": "~", + } + PARSEABLE_OPS = frozenset( + {"BINARY_OP", "BINARY_SUBSCR", "COMPARE_OP", "CONTAINS_OP", "IS_OP"} + | set(UNARY) + | set(CONTROL_FLOW) + | set(SYNTHETIC) + | LOAD_VALUES + ) + UNARY_VALUES = frozenset(UNARY.values()) + + +# numpy functions that we can map to native expressions +_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy")) +_NUMPY_FUNCTIONS = frozenset( + ( + # "abs", # TODO: this one clashes with Python builtin abs + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctanh", + "cbrt", + "ceil", + "cos", + "cosh", + "degrees", + "exp", + "floor", + "log", + "log10", + "log1p", + "radians", + "sign", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + ) +) + +# python attrs/funcs that map to native expressions +_PYTHON_ATTRS_MAP = { + "date": "dt.date()", + "day": "dt.day()", + "hour": "dt.hour()", + "microsecond": "dt.microsecond()", + "minute": "dt.minute()", + "month": "dt.month()", + "second": "dt.second()", + "year": "dt.year()", +} +_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"} +_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"} +_PYTHON_METHODS_MAP = { + # string + "endswith": "str.ends_with", + "lower": "str.to_lowercase", + "lstrip": "str.strip_chars_start", + "rstrip": "str.strip_chars_end", + "startswith": "str.starts_with", + "strip": "str.strip_chars", + "title": "str.to_titlecase", + "upper": "str.to_uppercase", + # temporal + "date": "dt.date", + "isoweekday": "dt.weekday", + "time": "dt.time", +} + +_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [ + # lambda x: numpy.func(x) + # lambda x: numpy.func(CONSTANT) + { + "argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}], + "argument_2_opname": [], + "module_opname": [OpNames.LOAD_ATTR], + "attribute_opname": [], + "module_name": [_NUMPY_MODULE_ALIASES], + "attribute_name": [], + "function_name": [_NUMPY_FUNCTIONS], + }, + # lambda x: json.loads(x) + { + "argument_1_opname": [{"LOAD_FAST"}], + "argument_2_opname": [], + "module_opname": [OpNames.LOAD_ATTR], + "attribute_opname": [], + "module_name": [{"json"}], + "attribute_name": [], + "function_name": [{"loads"}], + }, + # lambda x: datetime.strptime(x, CONSTANT) + { + "argument_1_opname": [{"LOAD_FAST"}], + "argument_2_opname": [{"LOAD_CONST"}], + "module_opname": [OpNames.LOAD_ATTR], + "attribute_opname": [], + "module_name": [{"datetime"}], + "attribute_name": [], + "function_name": [{"strptime"}], + }, + # lambda x: module.attribute.func(x, CONSTANT) + { + "argument_1_opname": [{"LOAD_FAST"}], + "argument_2_opname": [{"LOAD_CONST"}], + "module_opname": [{"LOAD_ATTR"}], + "attribute_opname": [OpNames.LOAD_ATTR], + "module_name": [{"datetime", "dt"}], + "attribute_name": [{"datetime"}], + "function_name": [{"strptime"}], + }, +] +# In addition to `lambda x: func(x)`, also support cases when a unary operation +# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`. +_MODULE_FUNCTIONS = [ + {**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item] + for kind in _MODULE_FUNCTIONS + for unary in [[set(OpNames.UNARY)], []] +] +_RE_IMPLICIT_BOOL = re.compile(r'pl\.col\("([^"]*)"\) & pl\.col\("\1"\)\.(.+)') + + +def _get_all_caller_variables() -> dict[str, Any]: + """Get all local and global variables from caller's frame.""" + pkg_dir = Path(__file__).parent.parent + + # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow + frame = inspect.currentframe() + n = 0 + try: + while frame: + fname = inspect.getfile(frame) + if fname.startswith(str(pkg_dir)): + frame = frame.f_back + n += 1 + else: + break + variables: dict[str, Any] + if frame is None: + variables = {} + else: + variables = {**frame.f_locals, **frame.f_globals} + finally: + # https://docs.python.org/3/library/inspect.html + # > Though the cycle detector will catch these, destruction of the frames + # > (and local variables) can be made deterministic by removing the cycle + # > in a finally clause. + del frame + return variables + + +class BytecodeParser: + """Introspect UDF bytecode and determine if we can rewrite as native expression.""" + + _map_target_name: str | None = None + + def __init__(self, function: Callable[[Any], Any], map_target: MapTarget): + try: + original_instructions = get_instructions(function) + except TypeError: + # in case we hit something that can't be disassembled (eg: code object + # unavailable, like a bare numpy ufunc that isn't in a lambda/function) + original_instructions = iter([]) + + self._function = function + self._map_target = map_target + self._param_name = self._get_param_name(function) + self._rewritten_instructions = RewrittenInstructions( + instructions=original_instructions, + ) + + def _omit_implicit_bool(self, expr: str) -> str: + """Drop extraneous/implied bool (eg: `pl.col("d") & pl.col("d").dt.date()`).""" + while _RE_IMPLICIT_BOOL.search(expr): + expr = _RE_IMPLICIT_BOOL.sub(repl=r'pl.col("\1").\2', string=expr) + return expr + + @staticmethod + def _get_param_name(function: Callable[[Any], Any]) -> str | None: + """Return single function parameter name.""" + try: + # note: we do not parse/handle functions with > 1 params + sig = signature(function) + except ValueError: + return None + return ( + next(iter(parameters.keys())) + if len(parameters := sig.parameters) == 1 + else None + ) + + def _inject_nesting( + self, + expression_blocks: dict[int, str], + logical_instructions: list[Instruction], + ) -> list[tuple[int, str]]: + """Inject nesting boundaries into expression blocks (as parentheses).""" + if logical_instructions: + # reconstruct nesting boundaries for mixed and/or ops by associating control + # flow jump offsets with their target expression blocks and applying parens + if len({inst.opname for inst in logical_instructions}) > 1: + block_offsets: list[int] = list(expression_blocks.keys()) + prev_end = -1 + for inst in logical_instructions: + start = block_offsets[bisect_left(block_offsets, inst.offset) - 1] + end = block_offsets[bisect_left(block_offsets, inst.argval) - 1] + if not (start == 0 and end == block_offsets[-1]): + if prev_end not in (start, end): + expression_blocks[start] = "(" + expression_blocks[start] + expression_blocks[end] += ")" + prev_end = end + + for inst in logical_instructions: # inject connecting "&" and "|" ops + expression_blocks[inst.offset] = OpNames.CONTROL_FLOW[inst.opname] + + return sorted(expression_blocks.items()) + + def _get_target_name(self, col: str, expression: str) -> str: + """The name of the object against which the 'map' is being invoked.""" + if self._map_target_name is not None: + return self._map_target_name + else: + col_expr = f'pl.col("{col}")' + if self._map_target == "expr": + return col_expr + elif self._map_target == "series": + # note: handle overlapping name from global variables; fallback + # through "s", "srs", "series" and (finally) srs0 -> srsN... + search_expr = expression.replace(col_expr, "") + for name in ("s", "srs", "series"): + if not re.search(rf"\b{name}\b", search_expr): + self._map_target_name = name + return name + n = count() + while True: + name = f"srs{next(n)}" + if not re.search(rf"\b{name}\b", search_expr): + self._map_target_name = name + return name + + msg = f"TODO: map_target = {self._map_target!r}" + raise NotImplementedError(msg) + + @property + def map_target(self) -> MapTarget: + """The map target, eg: one of 'expr', 'frame', or 'series'.""" + return self._map_target + + def can_attempt_rewrite(self) -> bool: + """ + Determine if we may be able to offer a native polars expression instead. + + Note that `lambda x: x` is inefficient, but we ignore it because it is not + guaranteed that using the equivalent bare constant value will return the + same output. (Hopefully nobody is writing lambdas like that anyway...) + """ + return ( + self._param_name is not None + # check minimum number of ops, ensuring all are parseable + and len(self._rewritten_instructions) >= 2 + and all( + inst.opname in OpNames.PARSEABLE_OPS + for inst in self._rewritten_instructions + ) + # exclude constructs/functions with multiple RETURN_VALUE ops + and sum( + 1 + for inst in self.original_instructions + if inst.opname == "RETURN_VALUE" + ) + == 1 + ) + + def dis(self) -> None: + """Print disassembled function bytecode.""" + dis.dis(self._function) + + @property + def function(self) -> Callable[[Any], Any]: + """The function being parsed.""" + return self._function + + @property + def original_instructions(self) -> list[Instruction]: + """The original bytecode instructions from the function we are parsing.""" + return list(self._rewritten_instructions._original_instructions) + + @property + def param_name(self) -> str | None: + """The parameter name of the function being parsed.""" + return self._param_name + + @property + def rewritten_instructions(self) -> list[Instruction]: + """The rewritten bytecode instructions from the function we are parsing.""" + return list(self._rewritten_instructions) + + def to_expression(self, col: str) -> str | None: + """Translate postfix bytecode instructions to polars expression/string.""" + self._map_target_name = None + if self._param_name is None: + return None + + # decompose bytecode into logical 'and'/'or' expression blocks (if present) + control_flow_blocks = defaultdict(list) + logical_instructions = [] + jump_offset = 0 + for idx, inst in enumerate(self._rewritten_instructions): + if inst.opname in OpNames.CONTROL_FLOW: + jump_offset = self._rewritten_instructions[idx + 1].offset + logical_instructions.append(inst) + else: + control_flow_blocks[jump_offset].append(inst) + + # convert each block to a polars expression string + caller_variables: dict[str, Any] = {} + try: + expression_strings = self._inject_nesting( + { + offset: InstructionTranslator( + instructions=ops, + caller_variables=caller_variables, + map_target=self._map_target, + ).to_expression( + col=col, + param_name=self._param_name, + depth=int(bool(logical_instructions)), + ) + for offset, ops in control_flow_blocks.items() + }, + logical_instructions, + ) + except NotImplementedError: + return None + polars_expr = " ".join(expr for _offset, expr in expression_strings) + + # note: if no 'pl.col' in the expression, it likely represents a compound + # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn + if "pl.col(" not in polars_expr: + return None + else: + polars_expr = self._omit_implicit_bool(polars_expr) + if self._map_target == "series": + target_name = self._get_target_name(col, polars_expr) + return polars_expr.replace(f'pl.col("{col}")', target_name) + else: + return polars_expr + + def warn( + self, + col: str, + suggestion_override: str | None = None, + udf_override: str | None = None, + ) -> None: + """Generate warning that suggests an equivalent native polars expression.""" + # Import these here so that udfs can be imported without polars installed. + + from polars._utils.various import ( + find_stacklevel, + in_terminal_that_supports_colour, + ) + from polars.exceptions import PolarsInefficientMapWarning + + suggested_expression = suggestion_override or self.to_expression(col) + + if suggested_expression is not None: + target_name = self._get_target_name(col, suggested_expression) + func_name = udf_override or self._function.__name__ or "..." + if func_name == "": + func_name = f"lambda {self._param_name}: ..." + + addendum = ( + 'Note: in list.eval context, pl.col("") should be written as pl.element()' + if 'pl.col("")' in suggested_expression + else "" + ) + if self._map_target == "expr": + apitype = "expressions" + clsname = "Expr" + else: + apitype = "series" + clsname = "Series" + + before, after = ( + ( + f" \033[31m- {target_name}.map_elements({func_name})\033[0m\n", + f" \033[32m+ {suggested_expression}\033[0m\n{addendum}", + ) + if in_terminal_that_supports_colour() + else ( + f" - {target_name}.map_elements({func_name})\n", + f" + {suggested_expression}\n{addendum}", + ) + ) + warnings.warn( + f"\n{clsname}.map_elements is significantly slower than the native {apitype} API.\n" + "Only use if you absolutely CANNOT implement your logic otherwise.\n" + "Replace this expression...\n" + f"{before}" + "with this one instead:\n" + f"{after}", + PolarsInefficientMapWarning, + stacklevel=find_stacklevel(), + ) + + +class InstructionTranslator: + """Translates Instruction bytecode to a polars expression string.""" + + def __init__( + self, + instructions: list[Instruction], + caller_variables: dict[str, Any], + map_target: MapTarget, + ) -> None: + self._caller_variables: dict[str, Any] = caller_variables + self._stack = self._to_intermediate_stack(instructions, map_target) + + def to_expression(self, col: str, param_name: str, depth: int) -> str: + """Convert intermediate stack to polars expression string.""" + return self._expr(self._stack, col, param_name, depth) + + @staticmethod + def op(inst: Instruction) -> str: + """Convert bytecode instruction to suitable intermediate op string.""" + if inst.opname in OpNames.CONTROL_FLOW: + return OpNames.CONTROL_FLOW[inst.opname] + elif inst.argrepr: + return inst.argrepr + elif inst.opname == "IS_OP": + return "is not" if inst.argval else "is" + elif inst.opname == "CONTAINS_OP": + return "not in" if inst.argval else "in" + elif inst.opname in OpNames.UNARY: + return OpNames.UNARY[inst.opname] + elif inst.opname == "BINARY_SUBSCR": + return "replace" + else: + msg = ( + "unrecognized opname" + "\n\nPlease report a bug to https://github.com/pola-rs/polars/issues" + " with the content of function you were passing to `map` and the" + f" following instruction object:\n{inst!r}" + ) + raise AssertionError(msg) + + def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str: + """Take stack entry value and convert to polars expression string.""" + if isinstance(value, StackValue): + op = value.operator + e1 = self._expr(value.left_operand, col, param_name, depth + 1) + if value.operator_arity == 1: + if op not in OpNames.UNARY_VALUES: + if e1.startswith("pl.col("): + call = "" if op.endswith(")") else "()" + return f"{e1}.{op}{call}" + if e1[0] in OpNames.UNARY_VALUES and e1[1:].startswith("pl.col("): + call = "" if op.endswith(")") else "()" + return f"({e1}).{op}{call}" + + # support use of consts as numpy/builtin params, eg: + # "np.sin(3) + np.cos(x)", or "len('const_string') + len(x)" + pfx = "np." if op in _NUMPY_FUNCTIONS else "" + return f"{pfx}{op}({e1})" + return f"{op}{e1}" + else: + e2 = self._expr(value.right_operand, col, param_name, depth + 1) + if op in ("is", "is not") and value[2] == "None": + not_ = "" if op == "is" else "not_" + return f"{e1}.is_{not_}null()" + elif op in ("in", "not in"): + not_ = "" if op == "in" else "~" + return ( + f"{not_}({e1}.is_in({e2}))" + if " " in e1 + else f"{not_}{e1}.is_in({e2})" + ) + elif op == "replace": + if not self._caller_variables: + self._caller_variables.update(_get_all_caller_variables()) + if not isinstance(self._caller_variables.get(e1, None), dict): + msg = "require dict mapping" + raise NotImplementedError(msg) + return f"{e2}.{op}({e1})" + elif op == "<<": + # Result of 2**e2 might be float is e2 was negative. + # But, if e1 << e2 was valid, then e2 must have been positive. + # Hence, the output of 2**e2 can be safely cast to Int64, which + # may be necessary if chaining operations which assume Int64 output. + return f"({e1} * 2**{e2}).cast(pl.Int64)" + elif op == ">>": + # Motivation for the cast is the same as in the '<<' case above. + return f"({e1} / 2**{e2}).cast(pl.Int64)" + else: + expr = f"{e1} {op} {e2}" + return f"({expr})" if depth else expr + + elif value == param_name: + return f'pl.col("{col}")' + + return value + + def _to_intermediate_stack( + self, instructions: list[Instruction], map_target: MapTarget + ) -> StackEntry: + """Take postfix bytecode and convert to an intermediate natural-order stack.""" + if map_target in ("expr", "series"): + stack: list[StackEntry] = [] + for inst in instructions: + stack.append( + inst.argrepr + if inst.opname in OpNames.LOAD + else ( + StackValue( + operator=self.op(inst), + operator_arity=1, + left_operand=stack.pop(), # type: ignore[arg-type] + right_operand=None, # type: ignore[arg-type] + ) + if ( + inst.opname in OpNames.UNARY + or OpNames.SYNTHETIC.get(inst.opname) == 1 + ) + else StackValue( + operator=self.op(inst), + operator_arity=2, + left_operand=stack.pop(-2), # type: ignore[arg-type] + right_operand=stack.pop(-1), # type: ignore[arg-type] + ) + ) + ) + return stack[0] + + # TODO: dataframe.apply(...) + msg = f"TODO: {map_target!r} apply" + raise NotImplementedError(msg) + + +class RewrittenInstructions: + """ + Standalone class that applies Instruction rewrite/filtering rules. + + This significantly simplifies subsequent parsing by injecting + synthetic POLARS_EXPRESSION ops into the Instruction stream for + easy identification/translation and separates the parsing logic + from the identification of expression translation opportunities. + """ + + _ignored_ops = frozenset( + [ + "COPY", + "COPY_FREE_VARS", + "POP_TOP", + "PRECALL", + "PUSH_NULL", + "RESUME", + "RETURN_VALUE", + ] + ) + _caller_variables: ClassVar[dict[str, Any]] = {} + + def __init__(self, instructions: Iterator[Instruction]): + self._original_instructions = list(instructions) + self._rewritten_instructions = self._rewrite( + self._upgrade_instruction(inst) + for inst in self._original_instructions + if inst.opname not in self._ignored_ops + ) + + def __len__(self) -> int: + return len(self._rewritten_instructions) + + def __iter__(self) -> Iterator[Instruction]: + return iter(self._rewritten_instructions) + + def __getitem__(self, item: Any) -> Instruction: + return self._rewritten_instructions[item] + + def _matches( + self, + idx: int, + *, + opnames: list[AbstractSet[str]], + argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None, + is_attr: bool = False, + ) -> list[Instruction]: + """ + Check if a sequence of Instructions matches the specified ops/argvals. + + Parameters + ---------- + idx + The index of the first instruction to check. + opnames + The full opname sequence that defines a match. + argvals + Associated argvals that must also match (in same position as opnames). + is_attr + Indicate if the match represents pure attribute access (cannot be called). + """ + n_required_ops, argvals = len(opnames), argvals or [] + idx_offset = idx + n_required_ops + if ( + is_attr + and (trailing_inst := self._instructions[idx_offset : idx_offset + 1]) + and trailing_inst[0].opname in OpNames.CALL # not pure attr if called + ): + return [] + + instructions = self._instructions[idx:idx_offset] + if len(instructions) == n_required_ops and all( + inst.opname in match_opnames + and (match_argval is None or inst.argval in match_argval) + for inst, match_opnames, match_argval in zip_longest( + instructions, opnames, argvals + ) + ): + return instructions + return [] + + def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: + """ + Apply rewrite rules, potentially injecting synthetic operations. + + Rules operate on the instruction stream and can examine/modify + it as needed, pushing updates into "updated_instructions" and + returning True/False to indicate if any changes were made. + """ + self._instructions = list(instructions) + updated_instructions: list[Instruction] = [] + idx = 0 + while idx < len(self._instructions): + inst, increment = self._instructions[idx], 1 + if inst.opname not in OpNames.LOAD or not any( + (increment := map_rewrite(idx, updated_instructions)) + for map_rewrite in ( + # add any other rewrite methods here + self._rewrite_functions, + self._rewrite_methods, + self._rewrite_builtins, + self._rewrite_attrs, + ) + ): + updated_instructions.append(inst) + idx += increment or 1 + return updated_instructions + + def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int: + """Replace python attribute lookup with synthetic POLARS_EXPRESSION op.""" + if matching_instructions := self._matches( + idx, + opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}], + argvals=[None, _PYTHON_ATTRS_MAP], + is_attr=True, + ): + inst = matching_instructions[1] + expr_name = _PYTHON_ATTRS_MAP[inst.argval] + px = inst._replace( + opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name + ) + updated_instructions.extend([matching_instructions[0], px]) + + return len(matching_instructions) + + def _rewrite_builtins( + self, idx: int, updated_instructions: list[Instruction] + ) -> int: + """Replace builtin function calls with a synthetic POLARS_EXPRESSION op.""" + if matching_instructions := self._matches( + idx, + opnames=[{"LOAD_GLOBAL"}, {"LOAD_FAST", "LOAD_CONST"}, OpNames.CALL], + argvals=[_PYTHON_BUILTINS], + ): + inst1, inst2 = matching_instructions[:2] + if (argval := inst1.argval) in _PYTHON_CASTS_MAP: + dtype = _PYTHON_CASTS_MAP[argval] + argval = f"cast(pl.{dtype})" + + px = inst1._replace( + opname="POLARS_EXPRESSION", + argval=argval, + argrepr=argval, + offset=inst2.offset, + ) + # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order + operand = inst2._replace(offset=inst1.offset) + updated_instructions.extend((operand, px)) + + return len(matching_instructions) + + def _rewrite_functions( + self, idx: int, updated_instructions: list[Instruction] + ) -> int: + """Replace function calls with a synthetic POLARS_EXPRESSION op.""" + for function_kind in _MODULE_FUNCTIONS: + opnames: list[AbstractSet[str]] = [ + {"LOAD_GLOBAL", "LOAD_DEREF"}, + *function_kind["module_opname"], + *function_kind["attribute_opname"], + *function_kind["argument_1_opname"], + *function_kind["argument_1_unary_opname"], + *function_kind["argument_2_opname"], + OpNames.CALL, + ] + if matching_instructions := self._matches( + idx, + opnames=opnames, + argvals=[ + *function_kind["module_name"], + *function_kind["attribute_name"], + *function_kind["function_name"], + ], + ): + attribute_count = len(function_kind["attribute_name"]) + inst1, inst2, inst3 = matching_instructions[ + attribute_count : 3 + attribute_count + ] + if inst1.argval == "json": + expr_name = "str.json_decode" + elif inst1.argval == "datetime": + fmt = matching_instructions[attribute_count + 3].argval + expr_name = f'str.to_datetime(format="{fmt}")' + if not self._is_stdlib_datetime( + inst1.argval, + matching_instructions[0].argval, + fmt, + attribute_count, + ): + return 0 + else: + expr_name = inst2.argval + + px = inst1._replace( + opname="POLARS_EXPRESSION", + argval=expr_name, + argrepr=expr_name, + offset=inst3.offset, + ) + + # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order + operand = inst3._replace(offset=inst1.offset) + updated_instructions.extend( + ( + operand, + matching_instructions[3 + attribute_count], + px, + ) + if function_kind["argument_1_unary_opname"] + else (operand, px) + ) + return len(matching_instructions) + + return 0 + + def _rewrite_methods( + self, idx: int, updated_instructions: list[Instruction] + ) -> int: + """Replace python method calls with synthetic POLARS_EXPRESSION op.""" + LOAD_METHOD = OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"} + if matching_instructions := ( + # method call with one basic arg, eg: "s.endswith('!')" + self._matches( + idx, + opnames=[LOAD_METHOD, {"LOAD_CONST"}, OpNames.CALL], + argvals=[_PYTHON_METHODS_MAP], + ) + or + # method call with no arg, eg: "s.lower()" + self._matches( + idx, + opnames=[LOAD_METHOD, OpNames.CALL], + argvals=[_PYTHON_METHODS_MAP], + ) + ): + inst = matching_instructions[0] + expr = _PYTHON_METHODS_MAP[inst.argval] + + if matching_instructions[1].opname == "LOAD_CONST": + param_value = matching_instructions[1].argval + if isinstance(param_value, tuple) and expr in ( + "str.starts_with", + "str.ends_with", + ): + starts, ends = ("^", "") if "starts" in expr else ("", "$") + rx = "|".join(re_escape(v) for v in param_value) + q = '"' if "'" in param_value else "'" + expr = f"str.contains(r{q}{starts}({rx}){ends}{q})" + else: + expr += f"({param_value!r})" + + px = inst._replace(opname="POLARS_EXPRESSION", argval=expr, argrepr=expr) + updated_instructions.append(px) + + return len(matching_instructions) + + @staticmethod + def _upgrade_instruction(inst: Instruction) -> Instruction: + """Rewrite any older binary opcodes using py 3.11 'BINARY_OP' instead.""" + if not _MIN_PY311 and inst.opname in OpNames.BINARY: + inst = inst._replace( + argrepr=OpNames.BINARY[inst.opname], + opname="BINARY_OP", + ) + return inst + + def _is_stdlib_datetime( + self, function_name: str, module_name: str, fmt: str, attribute_count: int + ) -> bool: + if not self._caller_variables: + self._caller_variables.update(_get_all_caller_variables()) + vars = self._caller_variables + return ( + attribute_count == 0 and vars.get(function_name) is datetime.datetime + ) or (attribute_count == 1 and vars.get(module_name) is datetime) + + +def _is_raw_function(function: Callable[[Any], Any]) -> tuple[str, str]: + """Identify translatable calls that aren't wrapped inside a lambda/function.""" + try: + func_module = function.__class__.__module__ + func_name = function.__name__ + except AttributeError: + return "", "" + + # numpy function calls + if func_module == "numpy" and func_name in _NUMPY_FUNCTIONS: + return "np", f"{func_name}()" + + # python function calls + elif func_module == "builtins": + if func_name in _PYTHON_CASTS_MAP: + return "builtins", f"cast(pl.{_PYTHON_CASTS_MAP[func_name]})" + elif func_name == "loads": + import json # double-check since it is referenced via 'builtins' + + if function is json.loads: + return "json", "str.json_decode()" + + return "", "" + + +def warn_on_inefficient_map( + function: Callable[[Any], Any], columns: list[str], map_target: MapTarget +) -> None: + """ + Generate `PolarsInefficientMapWarning` on poor usage of a `map` function. + + Parameters + ---------- + function + The function passed to `map`. + columns + The column names of the original object; in the case of an `Expr` this + will be a list of length 1 containing the expression's root name. + map_target + The target of the `map` call. One of `"expr"`, `"frame"`, + or `"series"`. + """ + if map_target == "frame": + msg = "TODO: 'frame' map-function parsing" + raise NotImplementedError(msg) + + # note: we only consider simple functions with a single col/param + if not (col := columns and columns[0]): + return None + + # the parser introspects function bytecode to determine if we can + # rewrite as a much more optimal native polars expression instead + parser = BytecodeParser(function, map_target) + if parser.can_attempt_rewrite(): + parser.warn(col) + else: + # handle bare numpy/json functions + module, suggestion = _is_raw_function(function) + if module and suggestion: + fn = function.__name__ + parser.warn( + col, + suggestion_override=f'pl.col("{col}").{suggestion}', + udf_override=fn if module == "builtins" else f"{module}.{fn}", + ) + + +__all__ = ["BytecodeParser", "warn_on_inefficient_map"] diff --git a/py-polars/polars/utils/unstable.py b/py-polars/polars/_utils/unstable.py similarity index 97% rename from py-polars/polars/utils/unstable.py rename to py-polars/polars/_utils/unstable.py index e00c9177e06b3..3ad2e4fde3065 100644 --- a/py-polars/polars/utils/unstable.py +++ b/py-polars/polars/_utils/unstable.py @@ -6,8 +6,8 @@ from functools import wraps from typing import TYPE_CHECKING, Callable, TypeVar +from polars._utils.various import find_stacklevel from polars.exceptions import UnstableWarning -from polars.utils.various import find_stacklevel if TYPE_CHECKING: import sys diff --git a/py-polars/polars/utils/various.py b/py-polars/polars/_utils/various.py similarity index 98% rename from py-polars/polars/utils/various.py rename to py-polars/polars/_utils/various.py index 1aa2f242d219e..e689cf2ea632c 100644 --- a/py-polars/polars/utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -385,7 +385,7 @@ def str_duration_(td: str | None) -> int | None: NS = TypeVar("NS") -class sphinx_accessor(property): # noqa: D101 +class sphinx_accessor(property): def __get__( # type: ignore[override] self, instance: Any, @@ -562,3 +562,11 @@ def parse_percentiles( at_or_above_50_percentiles = [0.5, *at_or_above_50_percentiles] return [*sub_50_percentiles, *at_or_above_50_percentiles] + + +def re_escape(s: str) -> str: + """Escape a string for use in a Polars (Rust) regex.""" + # note: almost the same as the standard python 're.escape' function, but + # escapes _only_ those metachars with meaning to the rust regex crate + re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-" + return re.sub(f"([{re_rust_metachars}])", r"\\\1", s) diff --git a/py-polars/polars/utils/_wrap.py b/py-polars/polars/_utils/wrap.py similarity index 100% rename from py-polars/polars/utils/_wrap.py rename to py-polars/polars/_utils/wrap.py diff --git a/py-polars/polars/api.py b/py-polars/polars/api.py index 97c45350d3368..af77ef41de181 100644 --- a/py-polars/polars/api.py +++ b/py-polars/polars/api.py @@ -6,7 +6,7 @@ from warnings import warn import polars._reexport as pl -from polars.utils.various import find_stacklevel +from polars._utils.various import find_stacklevel if TYPE_CHECKING: from polars import DataFrame, Expr, LazyFrame, Series diff --git a/py-polars/polars/config.py b/py-polars/polars/config.py index a481597f3ae0f..0aabe4d59a6a5 100644 --- a/py-polars/polars/config.py +++ b/py-polars/polars/config.py @@ -6,9 +6,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal, get_args +from polars._utils.deprecation import deprecate_nonkeyword_arguments +from polars._utils.various import normalize_filepath from polars.dependencies import json -from polars.utils.deprecation import deprecate_nonkeyword_arguments -from polars.utils.various import normalize_filepath if sys.version_info >= (3, 10): from typing import TypeAlias diff --git a/py-polars/polars/convert.py b/py-polars/polars/convert.py index 10c1b25da5b01..e9a390095cef9 100644 --- a/py-polars/polars/convert.py +++ b/py-polars/polars/convert.py @@ -7,12 +7,12 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.various import _cast_repr_strings_with_schema from polars.datatypes import N_INFER_DEFAULT, Categorical, List, Object, String, Struct from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa from polars.exceptions import NoDataError from polars.io import read_csv -from polars.utils.various import _cast_repr_strings_with_schema if TYPE_CHECKING: from polars import DataFrame, Series @@ -588,9 +588,12 @@ def from_arrow( 3 ] """ # noqa: W505 - if isinstance(data, pa.Table): + if isinstance(data, (pa.Table, pa.RecordBatch)): return pl.DataFrame._from_arrow( - data=data, rechunk=rechunk, schema=schema, schema_overrides=schema_overrides + data=data, + rechunk=rechunk, + schema=schema, + schema_overrides=schema_overrides, ) elif isinstance(data, (pa.Array, pa.ChunkedArray)): name = getattr(data, "_name", "") or "" @@ -606,8 +609,6 @@ def from_arrow( schema_overrides=schema_overrides, ) - if isinstance(data, pa.RecordBatch): - data = [data] if isinstance(data, Iterable): return pl.DataFrame._from_arrow( data=pa.Table.from_batches( @@ -632,8 +633,7 @@ def from_pandas( rechunk: bool = ..., nan_to_null: bool = ..., include_index: bool = ..., -) -> DataFrame: - ... +) -> DataFrame: ... @overload @@ -644,8 +644,7 @@ def from_pandas( rechunk: bool = ..., nan_to_null: bool = ..., include_index: bool = ..., -) -> Series: - ... +) -> Series: ... def from_pandas( @@ -657,7 +656,7 @@ def from_pandas( include_index: bool = False, ) -> DataFrame | Series: """ - Construct a Polars DataFrame or Series from a pandas DataFrame or Series. + Construct a Polars DataFrame or Series from a pandas DataFrame, Series, or Index. This operation clones data. @@ -676,6 +675,12 @@ def from_pandas( include_index : bool, default False Load any non-default pandas indexes as columns. + .. note:: + If the input is a pandas ``Series`` or ``DataFrame`` and has a nameless + index which just enumerates the rows, then it will not be included in the + result, regardless of this parameter. If you want to be sure to include it, + please call ``.reset_index()`` prior to calling this function. + Returns ------- DataFrame @@ -698,7 +703,7 @@ def from_pandas( │ 4 ┆ 5 ┆ 6 │ └─────┴─────┴─────┘ - Constructing a Series from a :class:`pd.Series`: + Constructing a Series from a :class:`pandas.Series`: >>> import pandas as pd >>> pd_series = pd.Series([1, 2, 3], name="pd") diff --git a/py-polars/polars/dataframe/_html.py b/py-polars/polars/dataframe/_html.py index 99f52ff94dc38..38a77cec60a81 100644 --- a/py-polars/polars/dataframe/_html.py +++ b/py-polars/polars/dataframe/_html.py @@ -1,4 +1,5 @@ """Module for formatting output data in HTML.""" + from __future__ import annotations import os @@ -58,15 +59,16 @@ def __init__( self.elements: list[str] = [] self.max_cols = max_cols self.max_rows = max_rows - self.series = from_series + self.from_series = from_series self.row_idx: Iterable[int] self.col_idx: Iterable[int] if max_rows < df.height: + half, rest = divmod(max_rows, 2) self.row_idx = [ - *list(range(max_rows // 2)), + *list(range(half + rest)), -1, - *list(range(df.height - max_rows // 2, df.height)), + *list(range(df.height - half, df.height)), ] else: self.row_idx = range(df.height) @@ -132,7 +134,7 @@ def render(self) -> list[str]: ): # format frame/series shape with '_' thousand-separators s = self.df.shape - shape = f"({s[0]:_},)" if self.series else f"({s[0]:_}, {s[1]:_})" + shape = f"({s[0]:_},)" if self.from_series else f"({s[0]:_}, {s[1]:_})" self.elements.append(f"shape: {shape}") diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 99b38aaabf510..6a51f0466a470 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1,4 +1,5 @@ """Module containing logic related to eager DataFrames.""" + from __future__ import annotations import contextlib @@ -13,7 +14,6 @@ IO, TYPE_CHECKING, Any, - BinaryIO, Callable, ClassVar, Collection, @@ -32,6 +32,43 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.construction import ( + arrow_to_pydf, + dataframe_to_pydf, + dict_to_pydf, + iterable_to_pydf, + numpy_to_idxs, + numpy_to_pydf, + pandas_to_pydf, + sequence_to_pydf, + series_to_pydf, +) +from polars._utils.convert import parse_as_duration_string +from polars._utils.deprecation import ( + deprecate_function, + deprecate_nonkeyword_arguments, + deprecate_parameter_as_positional, + deprecate_renamed_function, + deprecate_renamed_parameter, + deprecate_saturating, + issue_deprecation_warning, +) +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.unstable import issue_unstable_warning, unstable +from polars._utils.various import ( + _prepare_row_index_args, + _process_null_values, + handle_projection_columns, + is_bool_sequence, + is_int_sequence, + is_str_sequence, + normalize_filepath, + parse_version, + range_to_slice, + scale_bytes, + warn_null_comparison, +) +from polars._utils.wrap import wrap_expr, wrap_ldf, wrap_s from polars.dataframe._html import NotebookFormatter from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy from polars.datatypes import ( @@ -77,43 +114,6 @@ from polars.selectors import _expand_selector_dicts, _expand_selectors from polars.slice import PolarsSlice from polars.type_aliases import DbWriteMode -from polars.utils._construction import ( - arrow_to_pydf, - dict_to_pydf, - frame_to_pydf, - iterable_to_pydf, - numpy_to_idxs, - numpy_to_pydf, - pandas_to_pydf, - sequence_to_pydf, - series_to_pydf, -) -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr, wrap_ldf, wrap_s -from polars.utils.convert import _timedelta_to_pl_duration -from polars.utils.deprecation import ( - deprecate_function, - deprecate_nonkeyword_arguments, - deprecate_parameter_as_positional, - deprecate_renamed_function, - deprecate_renamed_parameter, - deprecate_saturating, - issue_deprecation_warning, -) -from polars.utils.unstable import issue_unstable_warning, unstable -from polars.utils.various import ( - _prepare_row_index_args, - _process_null_values, - handle_projection_columns, - is_bool_sequence, - is_int_sequence, - is_str_sequence, - normalize_filepath, - parse_version, - range_to_slice, - scale_bytes, - warn_null_comparison, -) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame @@ -205,7 +205,8 @@ class DataFrame: Two-dimensional data in various forms; dict input must contain Sequences, Generators, or a `range`. Sequence may contain Series or other Sequences. schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict - The DataFrame schema may be declared in several ways: + The schema of the resulting DataFrame. The schema may be declared in several + ways: * As a dict of {name:type} pairs; if type is None, it will be auto-inferred. * As a list of column names; in this case types are automatically inferred. @@ -214,6 +215,8 @@ class DataFrame: If you supply a list of column names that does not match the names in the underlying data, the names given here will overwrite them. The number of names given in the schema should match the underlying data dimensions. + + If set to `None` (default), the schema is inferred from the data. schema_overrides : dict, default None Support type specification or override of one or more columns; note that any dtypes inferred from the schema param will be overridden. @@ -221,19 +224,29 @@ class DataFrame: The number of entries in the schema should match the underlying data dimensions, unless a sequence of dictionaries is being passed, in which case a *partial* schema can be declared to prevent specific fields from being loaded. + strict : bool, default True + Throw an error if any `data` value does not exactly match the given or inferred + data type for that column. If set to `False`, values that do not match the data + type are cast to that data type or, if casting is not possible, set to null + instead. orient : {'col', 'row'}, default None Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If this does not yield conclusive results, column orientation is used. infer_schema_length : int or None - The maximum number of rows to scan for schema inference. - If set to `None`, the full data may be scanned *(this is slow)*. - This parameter only applies if the input data is a sequence or generator of - rows; other input is read as-is. + The maximum number of rows to scan for schema inference. If set to `None`, the + full data may be scanned *(this can be slow)*. This parameter only applies if + the input data is a sequence or generator of rows; other input is read as-is. nan_to_null : bool, default False If the data comes from one or more numpy arrays, can optionally convert input data np.nan values to null instead. This is a no-op for all other input data. + Notes + ----- + Polars explicitly does not support subclassing of its core data types. See + the following GitHub issue for possible workarounds: + https://github.com/pola-rs/polars/issues/2846#issuecomment-1711799869 + Examples -------- Constructing a DataFrame from a dictionary: @@ -335,19 +348,9 @@ class DataFrame: │ 1 ┆ 2 ┆ 3 │ │ 4 ┆ 5 ┆ 6 │ └─────┴─────┴─────┘ - - Notes - ----- - Some methods internally convert the DataFrame into a LazyFrame before collecting - the results back into a DataFrame. This can lead to unexpected behavior when using - a subclassed DataFrame. For example, - - >>> class MyDataFrame(pl.DataFrame): - ... pass - >>> isinstance(MyDataFrame().lazy().collect(), MyDataFrame) - False """ + _df: PyDataFrame _accessors: ClassVar[set[str]] = {"plot"} def __init__( @@ -356,6 +359,7 @@ def __init__( schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, + strict: bool = True, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, nan_to_null: bool = False, @@ -370,6 +374,7 @@ def __init__( data, schema=schema, schema_overrides=schema_overrides, + strict=strict, nan_to_null=nan_to_null, ) @@ -378,13 +383,14 @@ def __init__( data, schema=schema, schema_overrides=schema_overrides, + strict=strict, orient=orient, infer_schema_length=infer_schema_length, ) elif isinstance(data, pl.Series): self._df = series_to_pydf( - data, schema=schema, schema_overrides=schema_overrides + data, schema=schema, schema_overrides=schema_overrides, strict=strict ) elif _check_for_numpy(data) and isinstance(data, np.ndarray): @@ -392,18 +398,19 @@ def __init__( data, schema=schema, schema_overrides=schema_overrides, + strict=strict, orient=orient, nan_to_null=nan_to_null, ) elif _check_for_pyarrow(data) and isinstance(data, pa.Table): self._df = arrow_to_pydf( - data, schema=schema, schema_overrides=schema_overrides + data, schema=schema, schema_overrides=schema_overrides, strict=strict ) elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): self._df = pandas_to_pydf( - data, schema=schema, schema_overrides=schema_overrides + data, schema=schema, schema_overrides=schema_overrides, strict=strict ) elif not isinstance(data, Sized) and isinstance(data, (Generator, Iterable)): @@ -411,13 +418,14 @@ def __init__( data, schema=schema, schema_overrides=schema_overrides, + strict=strict, orient=orient, infer_schema_length=infer_schema_length, ) elif isinstance(data, pl.DataFrame): - self._df = frame_to_pydf( - data, schema=schema, schema_overrides=schema_overrides + self._df = dataframe_to_pydf( + data, schema=schema, schema_overrides=schema_overrides, strict=strict ) else: msg = ( @@ -537,7 +545,7 @@ def _from_numpy( @classmethod def _from_arrow( cls, - data: pa.Table, + data: pa.Table | pa.RecordBatch, schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, @@ -551,8 +559,8 @@ def _from_arrow( Parameters ---------- - data : arrow table, array, or sequence of sequences - Data representing an Arrow Table or Array. + data : arrow Table, RecordBatch, or sequence of sequences + Data representing an Arrow Table or RecordBatch. schema : Sequence of str, (str,DataType) pairs, or a {str:DataType,} dict The DataFrame schema may be declared in several ways: @@ -848,7 +856,7 @@ def _read_parquet( @classmethod def _read_avro( cls, - source: str | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, columns: Sequence[int] | Sequence[str] | None = None, n_rows: int | None = None, @@ -1537,8 +1545,7 @@ def _take_with_series(self, s: Series) -> DataFrame: return self._from_pydf(self._df.take_with_series(s._s)) @overload - def __getitem__(self, item: str) -> Series: - ... + def __getitem__(self, item: str) -> Series: ... @overload def __getitem__( @@ -1550,16 +1557,13 @@ def __getitem__( | tuple[int, MultiColSelector] | tuple[MultiRowSelector, MultiColSelector] ), - ) -> Self: - ... + ) -> Self: ... @overload - def __getitem__(self, item: tuple[int, int | str]) -> Any: - ... + def __getitem__(self, item: tuple[int, int | str]) -> Any: ... @overload - def __getitem__(self, item: tuple[MultiRowSelector, int | str]) -> Series: - ... + def __getitem__(self, item: tuple[MultiRowSelector, int | str]) -> Series: ... def __getitem__( self, @@ -1807,7 +1811,7 @@ def __deepcopy__(self, memo: None = None) -> Self: def _ipython_key_completions_(self) -> list[str]: return self.columns - def _repr_html_(self, **kwargs: Any) -> str: + def _repr_html_(self, *, _from_series: bool = False) -> str: """ Format output data in HTML for display in Jupyter Notebooks. @@ -1819,18 +1823,18 @@ def _repr_html_(self, **kwargs: Any) -> str: """ max_cols = int(os.environ.get("POLARS_FMT_MAX_COLS", default=75)) if max_cols < 0: - max_cols = self.shape[1] - max_rows = int(os.environ.get("POLARS_FMT_MAX_ROWS", default=25)) + max_cols = self.width + + max_rows = int(os.environ.get("POLARS_FMT_MAX_ROWS", default=10)) if max_rows < 0: - max_rows = self.shape[0] + max_rows = self.height - from_series = kwargs.get("from_series", False) return "".join( NotebookFormatter( self, max_cols=max_cols, max_rows=max_rows, - from_series=from_series, + from_series=_from_series, ).render() ) @@ -1917,19 +1921,16 @@ def to_arrow(self) -> pa.Table: return pa.Table.from_batches(record_batches) @overload - def to_dict(self, as_series: Literal[True] = ...) -> dict[str, Series]: - ... + def to_dict(self, as_series: Literal[True] = ...) -> dict[str, Series]: ... @overload - def to_dict(self, as_series: Literal[False]) -> dict[str, list[Any]]: - ... + def to_dict(self, as_series: Literal[False]) -> dict[str, list[Any]]: ... @overload def to_dict( self, as_series: bool, # noqa: FBT001 - ) -> dict[str, Series] | dict[str, list[Any]]: - ... + ) -> dict[str, Series] | dict[str, list[Any]]: ... @deprecate_nonkeyword_arguments(version="0.19.13") def to_dict( @@ -2048,8 +2049,9 @@ def to_numpy( structured: bool = False, # noqa: FBT001 *, order: IndexOrder = "fortran", - use_pyarrow: bool = True, + allow_copy: bool = True, writable: bool = False, + use_pyarrow: bool = True, ) -> np.ndarray[Any, Any]: """ Convert this DataFrame to a NumPy ndarray. @@ -2070,20 +2072,18 @@ def to_numpy( one-dimensional array. Note that this option only takes effect if `structured` is set to `False` and the DataFrame dtypes allow for a global dtype for all columns. - use_pyarrow - Use `pyarrow.Array.to_numpy - `_ - - function for the conversion to numpy if necessary. + allow_copy + Allow memory to be copied to perform the conversion. If set to `False`, + causes conversions that are not zero-copy to fail. writable Ensure the resulting array is writable. This will force a copy of the data if the array was created without copy, as the underlying Arrow data is immutable. + use_pyarrow + Use `pyarrow.Array.to_numpy + `_ - Notes - ----- - If you're attempting to convert String or Decimal to an array, you'll need to - install `pyarrow`. + function for the conversion to numpy if necessary. Examples -------- @@ -2117,7 +2117,15 @@ def to_numpy( rec.array([(1, 6.5, 'a'), (2, 7. , 'b'), (3, 8.5, 'c')], dtype=[('foo', 'u1'), ('bar', ' None: + if not allow_copy and not self.is_empty(): + msg = f"copy not allowed: {msg}" + raise RuntimeError(msg) + if structured: + raise_on_copy("cannot create structured array without copying data") + arrays = [] struct_dtype = [] for s in self.iter_columns(): @@ -2136,9 +2144,14 @@ def to_numpy( array = self._df.to_numpy_view() if array is not None: if writable and not array.flags.writeable: + raise_on_copy("cannot create writable array without copying data") array = array.copy() return array + raise_on_copy( + "only numeric data without nulls in Fortran-like order can be converted without copy" + ) + out = self._df.to_numpy(order) if out is None: return np.vstack( @@ -2412,8 +2425,7 @@ def write_json( *, pretty: bool = ..., row_oriented: bool = ..., - ) -> str: - ... + ) -> str: ... @overload def write_json( @@ -2422,8 +2434,7 @@ def write_json( *, pretty: bool = ..., row_oriented: bool = ..., - ) -> None: - ... + ) -> None: ... def write_json( self, @@ -2438,7 +2449,7 @@ def write_json( Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. If set to `None` (default), the output is returned as a string instead. pretty Pretty serialize json. @@ -2480,12 +2491,10 @@ def write_json( return None @overload - def write_ndjson(self, file: None = None) -> str: - ... + def write_ndjson(self, file: None = None) -> str: ... @overload - def write_ndjson(self, file: IOBase | str | Path) -> None: - ... + def write_ndjson(self, file: IOBase | str | Path) -> None: ... def write_ndjson(self, file: IOBase | str | Path | None = None) -> str | None: r""" @@ -2494,7 +2503,7 @@ def write_ndjson(self, file: IOBase | str | Path | None = None) -> str | None: Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. If set to `None` (default), the output is returned as a string instead. Examples @@ -2542,13 +2551,12 @@ def write_csv( float_precision: int | None = ..., null_value: str | None = ..., quote_style: CsvQuoteStyle | None = ..., - ) -> str: - ... + ) -> str: ... @overload def write_csv( self, - file: BytesIO | TextIOWrapper | str | Path, + file: str | Path | IO[str] | IO[bytes], *, include_bom: bool = ..., include_header: bool = ..., @@ -2562,14 +2570,13 @@ def write_csv( float_precision: int | None = ..., null_value: str | None = ..., quote_style: CsvQuoteStyle | None = ..., - ) -> None: - ... + ) -> None: ... @deprecate_renamed_parameter("quote", "quote_char", version="0.19.8") @deprecate_renamed_parameter("has_header", "include_header", version="0.19.13") def write_csv( self, - file: BytesIO | TextIOWrapper | str | Path | None = None, + file: str | Path | IO[str] | IO[bytes] | None = None, *, include_bom: bool = False, include_header: bool = True, @@ -2590,7 +2597,7 @@ def write_csv( Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. If set to `None` (default), the output is returned as a string instead. include_bom Whether to include UTF-8 BOM in the CSV output. @@ -2691,7 +2698,7 @@ def write_csv( def write_avro( self, - file: BinaryIO | BytesIO | str | Path, + file: str | Path | IO[bytes], compression: AvroCompression = "uncompressed", name: str = "", ) -> None: @@ -2701,7 +2708,7 @@ def write_avro( Parameters ---------- file - File path or writeable file-like object to which the data will be written. + File path or writable file-like object to which the data will be written. compression : {'uncompressed', 'snappy', 'deflate'} Compression method. Defaults to "uncompressed". name @@ -2733,7 +2740,7 @@ def write_avro( @deprecate_renamed_parameter("has_header", "include_header", version="0.19.13") def write_excel( self, - workbook: Workbook | BytesIO | Path | str | None = None, + workbook: Workbook | IO[bytes] | Path | str | None = None, worksheet: str | None = None, *, position: tuple[int, int] | str = "A1", @@ -3241,22 +3248,20 @@ def write_ipc( compression: IpcCompression = "uncompressed", *, future: bool = False, - ) -> BytesIO: - ... + ) -> BytesIO: ... @overload def write_ipc( self, - file: BinaryIO | BytesIO | str | Path, + file: str | Path | IO[bytes], compression: IpcCompression = "uncompressed", *, future: bool = False, - ) -> None: - ... + ) -> None: ... def write_ipc( self, - file: BinaryIO | BytesIO | str | Path | None, + file: str | Path | IO[bytes] | None, compression: IpcCompression = "uncompressed", *, future: bool = False, @@ -3269,7 +3274,7 @@ def write_ipc( Parameters ---------- file - Path or writeable file-like object to which the IPC data will be + Path or writable file-like object to which the IPC data will be written. If set to `None`, the output is returned as a BytesIO object. compression : {'uncompressed', 'lz4', 'zstd'} Compression method. Defaults to "uncompressed". @@ -3317,20 +3322,18 @@ def write_ipc_stream( self, file: None, compression: IpcCompression = "uncompressed", - ) -> BytesIO: - ... + ) -> BytesIO: ... @overload def write_ipc_stream( self, - file: BinaryIO | BytesIO | str | Path, + file: str | Path | IO[bytes], compression: IpcCompression = "uncompressed", - ) -> None: - ... + ) -> None: ... def write_ipc_stream( self, - file: BinaryIO | BytesIO | str | Path | None, + file: str | Path | IO[bytes] | None, compression: IpcCompression = "uncompressed", ) -> BytesIO | None: """ @@ -3341,7 +3344,7 @@ def write_ipc_stream( Parameters ---------- file - Path or writeable file-like object to which the IPC record batch data will + Path or writable file-like object to which the IPC record batch data will be written. If set to `None`, the output is returned as a BytesIO object. compression : {'uncompressed', 'lz4', 'zstd'} Compression method. Defaults to "uncompressed". @@ -3390,7 +3393,7 @@ def write_parquet( Parameters ---------- file - File path or writeable file-like object to which the result will be written. + File path or writable file-like object to which the result will be written. compression : {'lz4', 'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'zstd'} Choose "zstd" for good compression performance. Choose "lz4" for fast compression/decompression. @@ -3669,11 +3672,10 @@ def write_delta( target: str | Path | deltalake.DeltaTable, *, mode: Literal["error", "append", "overwrite", "ignore"] = ..., - overwrite_schema: bool = ..., + overwrite_schema: bool | None = ..., storage_options: dict[str, str] | None = ..., delta_write_options: dict[str, Any] | None = ..., - ) -> None: - ... + ) -> None: ... @overload def write_delta( @@ -3681,18 +3683,17 @@ def write_delta( target: str | Path | deltalake.DeltaTable, *, mode: Literal["merge"], - overwrite_schema: bool = ..., + overwrite_schema: bool | None = ..., storage_options: dict[str, str] | None = ..., delta_merge_options: dict[str, Any], - ) -> deltalake.table.TableMerger: - ... + ) -> deltalake.table.TableMerger: ... def write_delta( self, target: str | Path | deltalake.DeltaTable, *, mode: Literal["error", "append", "overwrite", "ignore", "merge"] = "error", - overwrite_schema: bool = False, + overwrite_schema: bool | None = None, storage_options: dict[str, str] | None = None, delta_write_options: dict[str, Any] | None = None, delta_merge_options: dict[str, Any] | None = None, @@ -3715,6 +3716,10 @@ def write_delta( with the existing data. overwrite_schema If True, allows updating the schema of the table. + + .. deprecated:: 0.20.14 + Use the parameter `delta_write_options` instead and pass + `{"schema_mode": "overwrite"}`. storage_options Extra options for the storage backends supported by `deltalake`. For cloud storages, this may include configurations for authentication etc. @@ -3770,12 +3775,14 @@ def write_delta( >>> df.write_delta(table_path, mode="append") # doctest: +SKIP Overwrite a Delta Lake table as a new version. - If the schemas of the new and old data are the same, setting - `overwrite_schema` is not required. + If the schemas of the new and old data are the same, specifying the + `schema_mode` is not required. >>> existing_table_path = "/path/to/delta-table/" >>> df.write_delta( - ... existing_table_path, mode="overwrite", overwrite_schema=True + ... existing_table_path, + ... mode="overwrite", + ... delta_write_options={"schema_mode": "overwrite"}, ... ) # doctest: +SKIP Write a DataFrame as a Delta Lake table to a cloud object store like S3. @@ -3805,9 +3812,6 @@ def write_delta( For all `TableMerger` methods, check the deltalake docs `here `__. - Schema evolution is not yet supported in by the `deltalake` package, therefore - `overwrite_schema` will not have any effect on a merge operation. - >>> df = pl.DataFrame( ... { ... "foo": [1, 2, 3, 4, 5], @@ -3831,6 +3835,13 @@ def write_delta( ... .execute() ... ) # doctest: +SKIP """ + if overwrite_schema is not None: + issue_deprecation_warning( + "The parameter `overwrite_schema` for `write_delta` is deprecated." + ' Use the parameter `delta_write_options` instead and pass `{"schema_mode": "overwrite"}`.', + version="0.20.14", + ) + from polars.io.delta import ( _check_for_unsupported_types, _check_if_delta_available, @@ -3863,13 +3874,15 @@ def write_delta( if delta_write_options is None: delta_write_options = {} + if overwrite_schema: + delta_write_options["schema_mode"] = "overwrite" + schema = delta_write_options.pop("schema", None) write_deltalake( table_or_uri=target, data=data, schema=schema, mode=mode, - overwrite_schema=overwrite_schema, storage_options=storage_options, large_dtypes=True, **delta_write_options, @@ -3909,9 +3922,9 @@ def estimated_size(self, unit: SizeUnit = "b") -> int | float: ... schema=[("x", pl.UInt32), ("y", pl.Float64), ("z", pl.String)], ... ) >>> df.estimated_size() - 28000000 + 17888890 >>> df.estimated_size("mb") - 26.702880859375 + 17.0601749420166 """ sz = self._df.estimated_size() return scale_bytes(sz, unit) @@ -4263,8 +4276,7 @@ def glimpse( max_items_per_column: int = ..., max_colname_length: int = ..., return_as_string: Literal[False] = ..., - ) -> None: - ... + ) -> None: ... @overload def glimpse( @@ -4273,8 +4285,7 @@ def glimpse( max_items_per_column: int = ..., max_colname_length: int = ..., return_as_string: Literal[True], - ) -> str: - ... + ) -> str: ... @overload def glimpse( @@ -4283,8 +4294,7 @@ def glimpse( max_items_per_column: int = ..., max_colname_length: int = ..., return_as_string: bool, - ) -> str | None: - ... + ) -> str | None: ... def glimpse( self, @@ -4444,10 +4454,11 @@ def describe( Customize which percentiles are displayed, applying linear interpolation: - >>> df.describe( - ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], - ... interpolation="linear", - ... ) + >>> with pl.Config(tbl_rows=12): + ... df.describe( + ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], + ... interpolation="linear", + ... ) shape: (11, 7) ┌────────────┬──────────┬──────────┬──────────┬──────┬────────────┬──────────┐ │ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │ @@ -5482,7 +5493,7 @@ def rolling( check_sorted: bool = True, ) -> RollingGroupBy: """ - Create rolling groups based on a time, Int32, or Int64 column. + Create rolling groups based on a temporal or integer column. Different from a `group_by_dynamic` the windows are now determined by the individual values and are not of constant intervals. For constant intervals use @@ -5526,11 +5537,6 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: - - - **"1i" # length 1** - - **"10i" # length 10** - Parameters ---------- index_column @@ -5540,8 +5546,8 @@ def rolling( then it must be sorted in ascending order within each group). In case of a rolling operation on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. + {UInt32, UInt64, Int32, Int64}. Note that the first three get temporarily + cast to Int64, so if performance matters use an Int64 column. period length of the window - must be non-negative offset @@ -6052,8 +6058,8 @@ def upsample( if offset is None: offset = "0ns" - every = _timedelta_to_pl_duration(every) - offset = _timedelta_to_pl_duration(offset) + every = parse_as_duration_string(every) + offset = parse_as_duration_string(offset) return self._from_pydf( self._df.upsample(by, time_column, every, offset, maintain_order) @@ -6154,41 +6160,111 @@ def join_asof( Examples -------- - >>> from datetime import datetime + >>> from datetime import date >>> gdp = pl.DataFrame( ... { - ... "date": [ - ... datetime(2016, 1, 1), - ... datetime(2017, 1, 1), - ... datetime(2018, 1, 1), - ... datetime(2019, 1, 1), - ... ], # note record date: Jan 1st (sorted!) - ... "gdp": [4164, 4411, 4566, 4696], + ... "date": pl.date_range( + ... date(2016, 1, 1), + ... date(2020, 1, 1), + ... "1y", + ... eager=True, + ... ), + ... "gdp": [4164, 4411, 4566, 4696, 4827], ... } - ... ).set_sorted("date") + ... ) + >>> gdp + shape: (5, 2) + ┌────────────┬──────┐ + │ date ┆ gdp │ + │ --- ┆ --- │ + │ date ┆ i64 │ + ╞════════════╪══════╡ + │ 2016-01-01 ┆ 4164 │ + │ 2017-01-01 ┆ 4411 │ + │ 2018-01-01 ┆ 4566 │ + │ 2019-01-01 ┆ 4696 │ + │ 2020-01-01 ┆ 4827 │ + └────────────┴──────┘ + >>> population = pl.DataFrame( ... { - ... "date": [ - ... datetime(2016, 5, 12), - ... datetime(2017, 5, 12), - ... datetime(2018, 5, 12), - ... datetime(2019, 5, 12), - ... ], # note record date: May 12th (sorted!) - ... "population": [82.19, 82.66, 83.12, 83.52], + ... "date": [date(2016, 3, 1), date(2018, 8, 1), date(2019, 1, 1)], + ... "population": [82.19, 82.66, 83.12], ... } - ... ).set_sorted("date") + ... ).sort("date") + >>> population + shape: (3, 2) + ┌────────────┬────────────┐ + │ date ┆ population │ + │ --- ┆ --- │ + │ date ┆ f64 │ + ╞════════════╪════════════╡ + │ 2016-03-01 ┆ 82.19 │ + │ 2018-08-01 ┆ 82.66 │ + │ 2019-01-01 ┆ 83.12 │ + └────────────┴────────────┘ + + Note how the dates don't quite match. If we join them using `join_asof` and + `strategy='backward'`, then each date from `population` which doesn't have an + exact match is matched with the closest earlier date from `gdp`: + >>> population.join_asof(gdp, on="date", strategy="backward") - shape: (4, 3) - ┌─────────────────────┬────────────┬──────┐ - │ date ┆ population ┆ gdp │ - │ --- ┆ --- ┆ --- │ - │ datetime[μs] ┆ f64 ┆ i64 │ - ╞═════════════════════╪════════════╪══════╡ - │ 2016-05-12 00:00:00 ┆ 82.19 ┆ 4164 │ - │ 2017-05-12 00:00:00 ┆ 82.66 ┆ 4411 │ - │ 2018-05-12 00:00:00 ┆ 83.12 ┆ 4566 │ - │ 2019-05-12 00:00:00 ┆ 83.52 ┆ 4696 │ - └─────────────────────┴────────────┴──────┘ + shape: (3, 3) + ┌────────────┬────────────┬──────┐ + │ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ i64 │ + ╞════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 4164 │ + │ 2018-08-01 ┆ 82.66 ┆ 4566 │ + │ 2019-01-01 ┆ 83.12 ┆ 4696 │ + └────────────┴────────────┴──────┘ + + Note how: + + - date `2016-03-01` from `population` is matched with `2016-01-01` from `gdp`; + - date `2018-08-01` from `population` is matched with `2018-01-01` from `gdp`. + + If we instead use `strategy='forward'`, then each date from `population` which + doesn't have an exact match is matched with the closest later date from `gdp`: + + >>> population.join_asof(gdp, on="date", strategy="forward") + shape: (3, 3) + ┌────────────┬────────────┬──────┐ + │ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ i64 │ + ╞════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 4411 │ + │ 2018-08-01 ┆ 82.66 ┆ 4696 │ + │ 2019-01-01 ┆ 83.12 ┆ 4696 │ + └────────────┴────────────┴──────┘ + + Note how: + + - date `2016-03-01` from `population` is matched with `2017-01-01` from `gdp`; + - date `2018-08-01` from `population` is matched with `2019-01-01` from `gdp`. + + Finally, `strategy='nearest'` gives us a mix of the two results above, as each + date from `population` which doesn't have an exact match is matched with the + closest date from `gdp`, regardless of whether it's earlier or later: + + >>> population.join_asof(gdp, on="date", strategy="nearest") + shape: (3, 3) + ┌────────────┬────────────┬──────┐ + │ date ┆ population ┆ gdp │ + │ --- ┆ --- ┆ --- │ + │ date ┆ f64 ┆ i64 │ + ╞════════════╪════════════╪══════╡ + │ 2016-03-01 ┆ 82.19 ┆ 4164 │ + │ 2018-08-01 ┆ 82.66 ┆ 4696 │ + │ 2019-01-01 ┆ 83.12 ┆ 4696 │ + └────────────┴────────────┴──────┘ + + Note how: + + - date `2016-03-01` from `population` is matched with `2016-01-01` from `gdp`; + - date `2018-08-01` from `population` is matched with `2019-01-01` from `gdp`. """ tolerance = deprecate_saturating(tolerance) if not isinstance(other, DataFrame): @@ -6260,7 +6336,7 @@ def join( * *outer_coalesce* Same as 'outer', but coalesces the key columns * *cross* - Returns the cartisian product of rows from both tables + Returns the Cartesian product of rows from both tables * *semi* Filter rows that have a match in the right table. * *anti* @@ -6489,7 +6565,7 @@ def map_rows( >>> df.select(pl.col("foo") * 2 + pl.col("bar")) # doctest: +IGNORE_RESULT """ # TODO: Enable warning for inefficient map - # from polars.utils.udfs import warn_on_inefficient_map + # from polars._utils.udfs import warn_on_inefficient_map # warn_on_inefficient_map(function, columns=self.columns, map_target="frame) out, is_df = self._df.map_rows(function, return_dtype, inference_size) @@ -7270,7 +7346,8 @@ def pivot( ---------- values Column values to aggregate. Can be multiple columns if the *columns* - arguments contains multiple columns as well. + arguments contains multiple columns as well. If None, all remaining columns + will be used. index One or multiple keys to group by. columns @@ -7386,9 +7463,10 @@ def pivot( │ b ┆ 0.964028 ┆ 0.999954 │ └──────┴──────────┴──────────┘ """ # noqa: W505 - values = _expand_selectors(self, values) index = _expand_selectors(self, index) columns = _expand_selectors(self, columns) + if values is not None: + values = _expand_selectors(self, values) if isinstance(aggregate_function, str): if aggregate_function == "first": @@ -7424,9 +7502,9 @@ def pivot( return self._from_pydf( self._df.pivot_expr( - values, index, columns, + values, maintain_order, sort_columns, aggregate_expr, @@ -7645,8 +7723,7 @@ def partition_by( maintain_order: bool = ..., include_key: bool = ..., as_dict: Literal[False] = ..., - ) -> list[Self]: - ... + ) -> list[Self]: ... @overload def partition_by( @@ -7656,8 +7733,7 @@ def partition_by( maintain_order: bool = ..., include_key: bool = ..., as_dict: Literal[True], - ) -> dict[Any, Self]: - ... + ) -> dict[Any, Self]: ... def partition_by( self, @@ -8324,12 +8400,10 @@ def with_columns_seq( return self.lazy().with_columns_seq(*exprs, **named_exprs).collect(_eager=True) @overload - def n_chunks(self, strategy: Literal["first"] = ...) -> int: - ... + def n_chunks(self, strategy: Literal["first"] = ...) -> int: ... @overload - def n_chunks(self, strategy: Literal["all"]) -> list[int]: - ... + def n_chunks(self, strategy: Literal["all"]) -> list[int]: ... def n_chunks(self, strategy: str = "first") -> int | list[int]: """ @@ -8368,16 +8442,13 @@ def n_chunks(self, strategy: str = "first") -> int | list[int]: raise ValueError(msg) @overload - def max(self, axis: Literal[0] = ...) -> Self: - ... + def max(self, axis: Literal[0] = ...) -> Self: ... @overload - def max(self, axis: Literal[1]) -> Series: - ... + def max(self, axis: Literal[1]) -> Series: ... @overload - def max(self, axis: int = 0) -> Self | Series: - ... + def max(self, axis: int = 0) -> Self | Series: ... def max(self, axis: int | None = None) -> Self | Series: """ @@ -8457,16 +8528,13 @@ def max_horizontal(self) -> Series: return self.select(max=F.max_horizontal(F.all())).to_series() @overload - def min(self, axis: Literal[0] | None = ...) -> Self: - ... + def min(self, axis: Literal[0] | None = ...) -> Self: ... @overload - def min(self, axis: Literal[1]) -> Series: - ... + def min(self, axis: Literal[1]) -> Series: ... @overload - def min(self, axis: int) -> Self | Series: - ... + def min(self, axis: int) -> Self | Series: ... def min(self, axis: int | None = None) -> Self | Series: """ @@ -8551,8 +8619,7 @@ def sum( *, axis: Literal[0] = ..., null_strategy: NullStrategy = "ignore", - ) -> Self: - ... + ) -> Self: ... @overload def sum( @@ -8560,8 +8627,7 @@ def sum( *, axis: Literal[1], null_strategy: NullStrategy = "ignore", - ) -> Series: - ... + ) -> Series: ... @overload def sum( @@ -8569,8 +8635,7 @@ def sum( *, axis: int, null_strategy: NullStrategy = "ignore", - ) -> Self | Series: - ... + ) -> Self | Series: ... def sum( self, @@ -8678,8 +8743,7 @@ def mean( *, axis: Literal[0] = ..., null_strategy: NullStrategy = "ignore", - ) -> Self: - ... + ) -> Self: ... @overload def mean( @@ -8687,8 +8751,7 @@ def mean( *, axis: Literal[1], null_strategy: NullStrategy = "ignore", - ) -> Series: - ... + ) -> Series: ... @overload def mean( @@ -8696,8 +8759,7 @@ def mean( *, axis: int, null_strategy: NullStrategy = "ignore", - ) -> Self | Series: - ... + ) -> Self | Series: ... def mean( self, @@ -9206,10 +9268,16 @@ def n_unique(self, subset: str | Expr | Sequence[str | Expr] | None = None) -> i df = self.lazy().select(expr.n_unique()).collect(_eager=True) return 0 if df.is_empty() else df.row(0)[0] + @deprecate_function( + "Use `select(pl.all().approx_n_unique())` instead.", version="0.20.11" + ) def approx_n_unique(self) -> DataFrame: """ Approximate count of unique values. + .. deprecated:: 0.20.11 + Use `select(pl.all().approx_n_unique())` instead. + This is done using the HyperLogLog++ algorithm for cardinality estimation. Examples @@ -9220,7 +9288,7 @@ def approx_n_unique(self) -> DataFrame: ... "b": [1, 2, 1, 1], ... } ... ) - >>> df.approx_n_unique() + >>> df.approx_n_unique() # doctest: +SKIP shape: (1, 2) ┌─────┬─────┐ │ a ┆ b │ @@ -9438,8 +9506,7 @@ def row( *, by_predicate: Expr | None = ..., named: Literal[False] = ..., - ) -> tuple[Any, ...]: - ... + ) -> tuple[Any, ...]: ... @overload def row( @@ -9448,8 +9515,7 @@ def row( *, by_predicate: Expr | None = ..., named: Literal[True], - ) -> dict[str, Any]: - ... + ) -> dict[str, Any]: ... def row( self, @@ -9558,12 +9624,10 @@ def row( raise ValueError(msg) @overload - def rows(self, *, named: Literal[False] = ...) -> list[tuple[Any, ...]]: - ... + def rows(self, *, named: Literal[False] = ...) -> list[tuple[Any, ...]]: ... @overload - def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: - ... + def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ... def rows( self, *, named: bool = False @@ -9799,14 +9863,12 @@ def rows_by_key( @overload def iter_rows( self, *, named: Literal[False] = ..., buffer_size: int = ... - ) -> Iterator[tuple[Any, ...]]: - ... + ) -> Iterator[tuple[Any, ...]]: ... @overload def iter_rows( self, *, named: Literal[True], buffer_size: int = ... - ) -> Iterator[dict[str, Any]]: - ... + ) -> Iterator[dict[str, Any]]: ... def iter_rows( self, *, named: bool = False, buffer_size: int = 512 diff --git a/py-polars/polars/dataframe/group_by.py b/py-polars/polars/dataframe/group_by.py index fd89b8256bd1d..719afd9c8ebd2 100644 --- a/py-polars/polars/dataframe/group_by.py +++ b/py-polars/polars/dataframe/group_by.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Callable, Iterable, Iterator from polars import functions as F -from polars.utils.convert import _timedelta_to_pl_duration -from polars.utils.deprecation import ( +from polars._utils.convert import parse_as_duration_string +from polars._utils.deprecation import ( deprecate_renamed_function, issue_deprecation_warning, ) @@ -478,6 +478,9 @@ def count(self) -> DataFrame: """ Return the number of rows in each group. + .. deprecated:: 0.20.5 + This method has been renamed to :func:`GroupBy.len`. + Rows containing null values count towards the total. Examples @@ -792,8 +795,8 @@ def __init__( by: IntoExpr | Iterable[IntoExpr] | None, check_sorted: bool, ): - period = _timedelta_to_pl_duration(period) - offset = _timedelta_to_pl_duration(offset) + period = parse_as_duration_string(period) + offset = parse_as_duration_string(offset) self.df = df self.time_column = index_column @@ -969,9 +972,9 @@ def __init__( start_by: StartBy, check_sorted: bool, ): - every = _timedelta_to_pl_duration(every) - period = _timedelta_to_pl_duration(period) - offset = _timedelta_to_pl_duration(offset) + every = parse_as_duration_string(every) + period = parse_as_duration_string(period) + offset = parse_as_duration_string(offset) self.df = df self.time_column = index_column diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index d9e2fde8614c9..c6bfacbdafea0 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -90,9 +90,6 @@ def is_nested(self) -> bool: # noqa: D102 class DataType(metaclass=DataTypeClass): """Base class for all Polars data types.""" - def __reduce__(self) -> Any: - return (_custom_reconstruct, (type(self), object, None), self.__dict__) - def _string_repr(self) -> str: return _dtype_str_repr(self) @@ -169,7 +166,7 @@ def is_not(self, other: PolarsDataType) -> bool: >>> pl.List.is_not(pl.List(pl.Int32)) # doctest: +SKIP True """ - from polars.utils.deprecation import issue_deprecation_warning + from polars._utils.deprecation import issue_deprecation_warning issue_deprecation_warning( "`DataType.is_not` is deprecated and will be removed in the next breaking release." @@ -219,19 +216,6 @@ def is_nested(cls) -> bool: return issubclass(cls, NestedType) -def _custom_reconstruct( - cls: type[Any], base: type[Any], state: Any -) -> PolarsDataType | type: - """Helper function for unpickling DataType objects.""" - if state: - obj = base.__new__(cls, state) - if base.__init__ != object.__init__: - base.__init__(obj, state) - else: - obj = object.__new__(cls) - return obj - - class DataTypeGroup(frozenset): # type: ignore[type-arg] """Group of data types.""" @@ -340,6 +324,14 @@ class Decimal(NumericType): This functionality is considered **unstable**. It is a work-in-progress feature and may not always work as expected. It may be changed at any point without it being considered a breaking change. + + Parameters + ---------- + precision + Maximum number of digits in each number. + If set to `None` (default), the precision is inferred. + scale + Number of digits to the right of the decimal point in each number. """ precision: int | None @@ -352,7 +344,7 @@ def __init__( ): # Issuing the warning on `__init__` does not trigger when the class is used # without being instantiated, but it's better than nothing - from polars.utils.unstable import issue_unstable_warning + from polars._utils.unstable import issue_unstable_warning issue_unstable_warning( "The Decimal data type is considered unstable." @@ -397,48 +389,80 @@ class Binary(DataType): class Date(TemporalType): - """Calendar date type.""" + """ + Data type representing a calendar date. + + Notes + ----- + The underlying representation of this type is a 32-bit signed integer. + The integer indicates the number of days since the Unix epoch (1970-01-01). + The number can be negative to indicate dates before the epoch. + """ class Time(TemporalType): - """Time of day type.""" + """ + Data type representing the time of day. + + Notes + ----- + The underlying representation of this type is a 64-bit signed integer. + The integer indicates the number of nanoseconds since midnight. + """ class Datetime(TemporalType): - """Calendar date and time type.""" + """ + Data type representing a calendar date and time of day. + + Parameters + ---------- + time_unit : {'us', 'ns', 'ms'} + Unit of time. Defaults to `'us'` (microseconds). + time_zone + Time zone string, as defined in zoneinfo (to see valid strings run + `import zoneinfo; zoneinfo.available_timezones()` for a full list). + When using to match dtypes, can use "*" to check for Datetime columns + that have any timezone. + + Notes + ----- + The underlying representation of this type is a 64-bit signed integer. + The integer indicates the number of time units since the Unix epoch + (1970-01-01 00:00:00). The number can be negative to indicate datetimes before the + epoch. + """ time_unit: TimeUnit | None = None time_zone: str | None = None def __init__( - self, time_unit: TimeUnit | None = "us", time_zone: str | timezone | None = None + self, time_unit: TimeUnit = "us", time_zone: str | timezone | None = None ): - """ - Calendar date and time type. - - Parameters - ---------- - time_unit : {'us', 'ns', 'ms'} - Unit of time / precision. - time_zone - Time zone string, as defined in zoneinfo (to see valid strings run - `import zoneinfo; zoneinfo.available_timezones()` for a full list). - When using to match dtypes, can use "*" to check for Datetime columns - that have any timezone. - """ - if isinstance(time_zone, timezone): - time_zone = str(time_zone) - - self.time_unit = time_unit or "us" - self.time_zone = time_zone + if time_unit is None: + from polars._utils.deprecation import issue_deprecation_warning + + issue_deprecation_warning( + "Passing `time_unit=None` to the Datetime constructor is deprecated." + " Either avoid passing a time unit to use the default value ('us')," + " or pass a valid time unit instead ('ms', 'us', 'ns').", + version="0.20.11", + ) + time_unit = "us" - if self.time_unit not in ("ms", "us", "ns"): + if time_unit not in ("ms", "us", "ns"): msg = ( "invalid `time_unit`" - f"\n\nExpected one of {{'ns','us','ms'}}, got {self.time_unit!r}." + f"\n\nExpected one of {{'ns','us','ms'}}, got {time_unit!r}." ) raise ValueError(msg) + if isinstance(time_zone, timezone): + time_zone = str(time_zone) + + self.time_unit = time_unit + self.time_zone = time_zone + def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] # allow comparing object instances to class if type(other) is DataTypeClass and issubclass(other, Datetime): @@ -461,27 +485,33 @@ def __repr__(self) -> str: class Duration(TemporalType): - """Time duration/delta type.""" + """ + Data type representing a time duration. + + Parameters + ---------- + time_unit : {'us', 'ns', 'ms'} + Unit of time. Defaults to `'us'` (microseconds). + + Notes + ----- + The underlying representation of this type is a 64-bit signed integer. + The integer indicates an amount of time units and can be negative to indicate + negative time offsets. + """ time_unit: TimeUnit | None = None def __init__(self, time_unit: TimeUnit = "us"): - """ - Time duration/delta type. - - Parameters - ---------- - time_unit : {'us', 'ns', 'ms'} - Unit of time. - """ - self.time_unit = time_unit - if self.time_unit not in ("ms", "us", "ns"): + if time_unit not in ("ms", "us", "ns"): msg = ( "invalid `time_unit`" - f"\n\nExpected one of {{'ns','us','ms'}}, got {self.time_unit!r}." + f"\n\nExpected one of {{'ns','us','ms'}}, got {time_unit!r}." ) raise ValueError(msg) + self.time_unit = time_unit + def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] # allow comparing object instances to class if type(other) is DataTypeClass and issubclass(other, Duration): @@ -505,9 +535,9 @@ class Categorical(DataType): Parameters ---------- - ordering : {'lexical', 'physical'} - Ordering by order of appearance (physical, default) - or string value (lexical). + ordering : {'lexical', 'physical'} + Ordering by order of appearance (`'physical'`, default) + or string value (`'lexical'`). """ ordering: CategoricalOrdering | None @@ -542,22 +572,19 @@ class Enum(DataType): This functionality is considered **unstable**. It is a work-in-progress feature and may not always work as expected. It may be changed at any point without it being considered a breaking change. + + Parameters + ---------- + categories + The categories in the dataset. Categories must be strings. """ categories: Series def __init__(self, categories: Series | Iterable[str]): - """ - A fixed set categorical encoding of a set of strings. - - Parameters - ---------- - categories - Valid categories in the dataset. - """ # Issuing the warning on `__init__` does not trigger when the class is used # without being instantiated, but it's better than nothing - from polars.utils.unstable import issue_unstable_warning + from polars._utils.unstable import issue_unstable_warning issue_unstable_warning( "The Enum data type is considered unstable." @@ -604,50 +631,49 @@ def __repr__(self) -> str: class Object(DataType): - """Type for wrapping arbitrary Python objects.""" + """Data type for wrapping arbitrary Python objects.""" class Null(DataType): - """Type representing Null / None values.""" + """Data type representing null values.""" class Unknown(DataType): - """Type representing Datatype values that could not be determined statically.""" + """Type representing DataType values that could not be determined statically.""" class List(NestedType): - """Variable length list type.""" + """ + Variable length list type. + + Parameters + ---------- + inner + The `DataType` of the values within each list. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "integer_lists": [[1, 2], [3, 4]], + ... "float_lists": [[1.0, 2.0], [3.0, 4.0]], + ... } + ... ) + >>> df + shape: (2, 2) + ┌───────────────┬─────────────┐ + │ integer_lists ┆ float_lists │ + │ --- ┆ --- │ + │ list[i64] ┆ list[f64] │ + ╞═══════════════╪═════════════╡ + │ [1, 2] ┆ [1.0, 2.0] │ + │ [3, 4] ┆ [3.0, 4.0] │ + └───────────────┴─────────────┘ + """ inner: PolarsDataType | None = None def __init__(self, inner: PolarsDataType | PythonDataType): - """ - Variable length list type. - - Parameters - ---------- - inner - The `DataType` of the values within each list. - - Examples - -------- - >>> df = pl.DataFrame( - ... { - ... "integer_lists": [[1, 2], [3, 4]], - ... "float_lists": [[1.0, 2.0], [3.0, 4.0]], - ... } - ... ) - >>> df - shape: (2, 2) - ┌───────────────┬─────────────┐ - │ integer_lists ┆ float_lists │ - │ --- ┆ --- │ - │ list[i64] ┆ list[f64] │ - ╞═══════════════╪═════════════╡ - │ [1, 2] ┆ [1.0, 2.0] │ - │ [3, 4] ┆ [3.0, 4.0] │ - └───────────────┴─────────────┘ - """ self.inner = polars.datatypes.py_type_to_dtype(inner) def __eq__(self, other: PolarsDataType) -> bool: # type: ignore[override] @@ -677,33 +703,32 @@ def __repr__(self) -> str: class Array(NestedType): - """Fixed length list type.""" + """ + Fixed length list type. + + Parameters + ---------- + inner + The `DataType` of the values within each array. + width + The length of the arrays. + + Examples + -------- + >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) + >>> s + shape: (2,) + Series: 'a' [array[i64, 2]] + [ + [1, 2] + [4, 3] + ] + """ inner: PolarsDataType | None = None width: int def __init__(self, inner: PolarsDataType | PythonDataType, width: int): - """ - Fixed length list type. - - Parameters - ---------- - inner - The `DataType` of the values within each array. - width - The length of the arrays. - - Examples - -------- - >>> s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2)) - >>> s - shape: (2,) - Series: 'a' [array[i64, 2]] - [ - [1, 2] - [4, 3] - ] - """ self.inner = polars.datatypes.py_type_to_dtype(inner) self.width = width @@ -736,19 +761,21 @@ def __repr__(self) -> str: class Field: - """Definition of a single field within a `Struct` DataType.""" + """ + Definition of a single field within a `Struct` DataType. - def __init__(self, name: str, dtype: PolarsDataType): - """ - Definition of a single field within a `Struct` DataType. + Parameters + ---------- + name + The name of the field within its parent `Struct`. + dtype + The `DataType` of the field's values. + """ - Parameters - ---------- - name - The name of the field within its parent `Struct` - dtype - The `DataType` of the field's values - """ + name: str + dtype: PolarsDataType + + def __init__(self, name: str, dtype: PolarsDataType): self.name = name self.dtype = polars.datatypes.py_type_to_dtype(dtype) @@ -764,49 +791,46 @@ def __repr__(self) -> str: class Struct(NestedType): - """Struct composite type.""" + """ + Struct composite type. + + Parameters + ---------- + fields + The fields that make up the struct. Can be either a sequence of Field + objects or a mapping of column names to data types. + + Examples + -------- + Initialize using a dictionary: + + >>> dtype = pl.Struct({"a": pl.Int8, "b": pl.List(pl.String)}) + >>> dtype + Struct({'a': Int8, 'b': List(String)}) + + Initialize using a list of Field objects: + + >>> dtype = pl.Struct([pl.Field("a", pl.Int8), pl.Field("b", pl.List(pl.String))]) + >>> dtype + Struct({'a': Int8, 'b': List(String)}) + + When initializing a Series, Polars can infer a struct data type from the data. + + >>> s = pl.Series([{"a": 1, "b": ["x", "y"]}, {"a": 2, "b": ["z"]}]) + >>> s + shape: (2,) + Series: '' [struct[2]] + [ + {1,["x", "y"]} + {2,["z"]} + ] + >>> s.dtype + Struct({'a': Int64, 'b': List(String)}) + """ fields: list[Field] def __init__(self, fields: Sequence[Field] | SchemaDict): - """ - Struct composite type. - - Parameters - ---------- - fields - The fields that make up the struct. Can be either a sequence of Field - objects or a mapping of column names to data types. - - Examples - -------- - Initialize using a dictionary: - - >>> dtype = pl.Struct({"a": pl.Int8, "b": pl.List(pl.String)}) - >>> dtype - Struct({'a': Int8, 'b': List(String)}) - - Initialize using a list of Field objects: - - >>> dtype = pl.Struct( - ... [pl.Field("a", pl.Int8), pl.Field("b", pl.List(pl.String))] - ... ) - >>> dtype - Struct({'a': Int8, 'b': List(String)}) - - When initializing a Series, Polars can infer a struct data type from the data. - - >>> s = pl.Series([{"a": 1, "b": ["x", "y"]}, {"a": 2, "b": ["z"]}]) - >>> s - shape: (2,) - Series: '' [struct[2]] - [ - {1,["x", "y"]} - {2,["z"]} - ] - >>> s.dtype - Struct({'a': Int64, 'b': List(String)}) - """ if isinstance(fields, Mapping): self.fields = [Field(name, dtype) for name, dtype in fields.items()] else: diff --git a/py-polars/polars/datatypes/convert.py b/py-polars/polars/datatypes/convert.py index 9cd0ab6c7cfba..d44bf9672908c 100644 --- a/py-polars/polars/datatypes/convert.py +++ b/py-polars/polars/datatypes/convert.py @@ -66,7 +66,7 @@ if TYPE_CHECKING: from typing import Literal - from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict + from polars.type_aliases import PolarsDataType, PythonDataType, SchemaDict, TimeUnit PY_STR_TO_DTYPE: SchemaDict = { @@ -246,12 +246,12 @@ def DTYPE_TO_CTYPE(self) -> dict[PolarsDataType, Any]: Int8: ctypes.c_int8, Int16: ctypes.c_int16, Int32: ctypes.c_int32, - Date: ctypes.c_int32, Int64: ctypes.c_int64, Float32: ctypes.c_float, Float64: ctypes.c_double, Datetime: ctypes.c_int64, Duration: ctypes.c_int64, + Date: ctypes.c_int32, Time: ctypes.c_int64, } @@ -298,6 +298,8 @@ def NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE(self) -> dict[tuple[str, int], PolarsDataTy ("u", 8): UInt64, ("f", 4): Float32, ("f", 8): Float64, + ("m", 8): Duration, + ("M", 8): Datetime, } @property @@ -368,15 +370,13 @@ def dtype_to_py_type(dtype: PolarsDataType) -> PythonDataType: @overload def py_type_to_dtype( data_type: Any, *, raise_unmatched: Literal[True] = ... -) -> PolarsDataType: - ... +) -> PolarsDataType: ... @overload def py_type_to_dtype( data_type: Any, *, raise_unmatched: Literal[False] -) -> PolarsDataType | None: - ... +) -> PolarsDataType | None: ... def py_type_to_dtype( @@ -469,6 +469,8 @@ def numpy_char_code_to_dtype(dtype_char: str) -> PolarsDataType: dtype = np.dtype(dtype_char) if dtype.kind == "U": return String + elif dtype.kind == "S": + return Binary try: return DataTypeMappings.NUMPY_KIND_AND_ITEMSIZE_TO_DTYPE[ (dtype.kind, dtype.itemsize) @@ -481,17 +483,18 @@ def numpy_char_code_to_dtype(dtype_char: str) -> PolarsDataType: def maybe_cast(el: Any, dtype: PolarsDataType) -> Any: """Try casting a value to a value that is valid for the given Polars dtype.""" # cast el if it doesn't match - from polars.utils.convert import ( - _datetime_to_pl_timestamp, - _timedelta_to_pl_timedelta, + from polars._utils.convert import ( + datetime_to_int, + timedelta_to_int, ) + time_unit: TimeUnit if isinstance(el, datetime): - time_unit = getattr(dtype, "time_unit", None) - return _datetime_to_pl_timestamp(el, time_unit) + time_unit = getattr(dtype, "time_unit", "us") + return datetime_to_int(el, time_unit) elif isinstance(el, timedelta): - time_unit = getattr(dtype, "time_unit", None) - return _timedelta_to_pl_timedelta(el, time_unit) + time_unit = getattr(dtype, "time_unit", "us") + return timedelta_to_int(el, time_unit) py_type = dtype_to_py_type(dtype) if not isinstance(el, py_type): diff --git a/py-polars/polars/dependencies.py b/py-polars/polars/dependencies.py index 1cc61eb4609cf..61d37c3877a80 100644 --- a/py-polars/polars/dependencies.py +++ b/py-polars/polars/dependencies.py @@ -244,8 +244,8 @@ def import_optional( min_version : {str, tuple[int]}, optional If a minimum module version is required, specify it here. """ + from polars._utils.various import parse_version from polars.exceptions import ModuleUpgradeRequired - from polars.utils.various import parse_version try: module = import_module(module_name) diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index b228b7b562b74..6972d5e1f0620 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Callable, Sequence -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr if TYPE_CHECKING: from datetime import date, datetime, time diff --git a/py-polars/polars/expr/binary.py b/py-polars/polars/expr/binary.py index 461c188822a71..978df94bb0635 100644 --- a/py-polars/polars/expr/binary.py +++ b/py-polars/polars/expr/binary.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr if TYPE_CHECKING: from polars import Expr @@ -162,8 +162,8 @@ def starts_with(self, prefix: IntoExpr) -> Expr: return wrap_expr(self._pyexpr.bin_starts_with(prefix)) def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -172,6 +172,33 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Expr + Expression of data type :class:`String`. + + Examples + -------- + >>> colors = pl.DataFrame( + ... { + ... "name": ["black", "yellow", "blue"], + ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], + ... } + ... ) + >>> colors.with_columns( + ... pl.col("code").bin.encode("hex").alias("encoded"), + ... ) + shape: (3, 3) + ┌────────┬─────────────────┬─────────┐ + │ name ┆ code ┆ encoded │ + │ --- ┆ --- ┆ --- │ + │ str ┆ binary ┆ str │ + ╞════════╪═════════════════╪═════════╡ + │ black ┆ b"\x00\x00\x00" ┆ 000000 │ + │ yellow ┆ b"\xff\xff\x00" ┆ ffff00 │ + │ blue ┆ b"\x00\x00\xff" ┆ 0000ff │ + └────────┴─────────────────┴─────────┘ """ if encoding == "hex": return wrap_expr(self._pyexpr.bin_hex_decode(strict)) @@ -193,30 +220,29 @@ def encode(self, encoding: TransferEncoding) -> Expr: Returns ------- Expr - Expression of data type :class:`String` with values encoded using provided - encoding. + Expression of data type :class:`String`. Examples -------- >>> colors = pl.DataFrame( ... { - ... "name": ["black", "yellow", "blue"], + ... "color": ["black", "yellow", "blue"], ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], ... } ... ) >>> colors.with_columns( - ... pl.col("code").bin.encode("hex").alias("code_encoded_hex"), + ... pl.col("code").bin.encode("hex").alias("encoded"), ... ) shape: (3, 3) - ┌────────┬─────────────────┬──────────────────┐ - │ name ┆ code ┆ code_encoded_hex │ - │ --- ┆ --- ┆ --- │ - │ str ┆ binary ┆ str │ - ╞════════╪═════════════════╪══════════════════╡ - │ black ┆ b"\x00\x00\x00" ┆ 000000 │ - │ yellow ┆ b"\xff\xff\x00" ┆ ffff00 │ - │ blue ┆ b"\x00\x00\xff" ┆ 0000ff │ - └────────┴─────────────────┴──────────────────┘ + ┌────────┬─────────────────┬─────────┐ + │ color ┆ code ┆ encoded │ + │ --- ┆ --- ┆ --- │ + │ str ┆ binary ┆ str │ + ╞════════╪═════════════════╪═════════╡ + │ black ┆ b"\x00\x00\x00" ┆ 000000 │ + │ yellow ┆ b"\xff\xff\x00" ┆ ffff00 │ + │ blue ┆ b"\x00\x00\xff" ┆ 0000ff │ + └────────┴─────────────────┴─────────┘ """ if encoding == "hex": return wrap_expr(self._pyexpr.bin_hex_encode()) diff --git a/py-polars/polars/expr/categorical.py b/py-polars/polars/expr/categorical.py index 89ecef5188ea2..ca00114c4e364 100644 --- a/py-polars/polars/expr/categorical.py +++ b/py-polars/polars/expr/categorical.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import deprecate_function +from polars._utils.deprecation import deprecate_function +from polars._utils.wrap import wrap_expr if TYPE_CHECKING: from polars import Expr diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index e1cb23331b582..5c28476018fd8 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -5,18 +5,18 @@ import polars._reexport as pl from polars import functions as F -from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Int32 -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr -from polars.utils.convert import _timedelta_to_pl_duration -from polars.utils.deprecation import ( +from polars._utils.convert import parse_as_duration_string +from polars._utils.deprecation import ( deprecate_function, deprecate_renamed_function, deprecate_saturating, issue_deprecation_warning, rename_use_earliest_to_ambiguous, ) -from polars.utils.unstable import unstable +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.unstable import unstable +from polars._utils.wrap import wrap_expr +from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Int32 if TYPE_CHECKING: from datetime import timedelta @@ -188,7 +188,7 @@ def truncate( every = deprecate_saturating(every) offset = deprecate_saturating(offset) if not isinstance(every, pl.Expr): - every = _timedelta_to_pl_duration(every) + every = parse_as_duration_string(every) if use_earliest is not None: issue_deprecation_warning( @@ -208,7 +208,7 @@ def truncate( return wrap_expr( self._pyexpr.dt_truncate( every, - _timedelta_to_pl_duration(offset), + parse_as_duration_string(offset), ) ) @@ -249,7 +249,7 @@ def round( - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime - .. deprecated: 0.19.3 + .. deprecated:: 0.19.3 This is now auto-inferred, you can safely remove this argument. Returns @@ -349,8 +349,8 @@ def round( return wrap_expr( self._pyexpr.dt_round( - _timedelta_to_pl_duration(every), - _timedelta_to_pl_duration(offset), + parse_as_duration_string(every), + parse_as_duration_string(offset), ) ) @@ -1629,6 +1629,7 @@ def replace_time_zone( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime + - `'null'`: set to null Examples -------- @@ -1665,7 +1666,7 @@ def replace_time_zone( │ 2020-07-01 01:00:00 BST ┆ 2020-07-01 01:00:00 CEST │ └─────────────────────────────┴────────────────────────────────┘ - You can use `use_earliest` to deal with ambiguous datetimes: + You can use `ambiguous` to deal with ambiguous datetimes: >>> dates = [ ... "2018-10-28 01:30", diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index d5e5f4e651e75..53f202b52572a 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -6,6 +6,8 @@ import warnings from datetime import timedelta from functools import reduce +from io import BytesIO, StringIO +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -23,6 +25,29 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.convert import negate_duration_string, parse_as_duration_string +from polars._utils.deprecation import ( + deprecate_function, + deprecate_nonkeyword_arguments, + deprecate_renamed_function, + deprecate_renamed_parameter, + deprecate_saturating, + issue_deprecation_warning, +) +from polars._utils.parse_expr_input import ( + parse_as_expression, + parse_as_list_of_expressions, + parse_predicates_constraints_as_expression, +) +from polars._utils.unstable import issue_unstable_warning, unstable +from polars._utils.various import ( + BUILDING_SPHINX_DOCS, + find_stacklevel, + no_default, + normalize_filepath, + sphinx_accessor, + warn_null_comparison, +) from polars.datatypes import ( Int64, is_polars_dtype, @@ -41,28 +66,6 @@ from polars.expr.string import ExprStringNameSpace from polars.expr.struct import ExprStructNameSpace from polars.meta import thread_pool_size -from polars.utils._parse_expr_input import ( - parse_as_expression, - parse_as_list_of_expressions, - parse_predicates_constraints_as_expression, -) -from polars.utils.convert import _negate_duration, _timedelta_to_pl_duration -from polars.utils.deprecation import ( - deprecate_function, - deprecate_nonkeyword_arguments, - deprecate_renamed_function, - deprecate_renamed_parameter, - deprecate_saturating, - issue_deprecation_warning, -) -from polars.utils.unstable import issue_unstable_warning, unstable -from polars.utils.various import ( - BUILDING_SPHINX_DOCS, - find_stacklevel, - no_default, - sphinx_accessor, - warn_null_comparison, -) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import arg_where as py_arg_where @@ -72,8 +75,12 @@ if TYPE_CHECKING: import sys + from io import IOBase from polars import DataFrame, LazyFrame, Series + from polars._utils.various import ( + NoDefault, + ) from polars.type_aliases import ( ClosedInterval, FillNullStrategy, @@ -90,9 +97,6 @@ TemporalLiteral, WindowMappingStrategy, ) - from polars.utils.various import ( - NoDefault, - ) if sys.version_info >= (3, 11): from typing import Concatenate, ParamSpec, Self @@ -329,17 +333,36 @@ def function(s: Series) -> Series: # pragma: no cover return root_expr.map_batches(function, is_elementwise=True).meta.undo_aliases() @classmethod - def from_json(cls, value: str) -> Self: + def deserialize(cls, source: str | Path | IOBase) -> Self: """ - Read an expression from a JSON encoded string to construct an Expression. + Read an expression from a JSON file. Parameters ---------- - value - JSON encoded string value + source + Path to a file or a file-like object (by file-like object, we refer to + objects that have a `read()` method, such as a file handler (e.g. + via builtin `open` function) or `BytesIO`). + + See Also + -------- + Expr.meta.serialize + + Examples + -------- + >>> from io import StringIO + >>> expr = pl.col("foo").sum().over("bar") + >>> json = expr.meta.serialize() + >>> pl.Expr.deserialize(StringIO(json)) # doctest: +ELLIPSIS + """ + if isinstance(source, StringIO): + source = BytesIO(source.getvalue().encode()) + elif isinstance(source, (str, Path)): + source = normalize_filepath(source) + expr = cls.__new__(cls) - expr._pyexpr = PyExpr.meta_read_json(value) + expr._pyexpr = PyExpr.deserialize(source) return expr def to_physical(self) -> Self: @@ -3299,7 +3322,7 @@ def rolling( check_sorted: bool = True, ) -> Self: """ - Create rolling groups based on a time, Int32, or Int64 column. + Create rolling groups based on a temporal or integer column. If you have a time series ``, then by default the windows created will be @@ -3339,11 +3362,6 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: - - - "1i" # length 1 - - "10i" # length 10 - Parameters ---------- index_column @@ -3351,8 +3369,8 @@ def rolling( Often of type Date/Datetime. This column must be sorted in ascending order. In case of a rolling group by on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. + {UInt32, UInt64, Int32, Int64}. Note that the first three get temporarily + cast to Int64, so if performance matters use an Int64 column. period length of the window - must be non-negative offset @@ -3401,10 +3419,10 @@ def rolling( period = deprecate_saturating(period) offset = deprecate_saturating(offset) if offset is None: - offset = _negate_duration(_timedelta_to_pl_duration(period)) + offset = negate_duration_string(parse_as_duration_string(period)) - period = _timedelta_to_pl_duration(period) - offset = _timedelta_to_pl_duration(offset) + period = parse_as_duration_string(period) + offset = parse_as_duration_string(offset) return self._from_pyexpr( self._pyexpr.rolling(index_column, period, offset, closed, check_sorted) @@ -4339,7 +4357,7 @@ def map_elements( ) # input x: Series of type list containing the group values - from polars.utils.udfs import warn_on_inefficient_map + from polars._utils.udfs import warn_on_inefficient_map root_names = self.meta.root_names() if len(root_names) > 0: @@ -5716,16 +5734,13 @@ def rolling_min( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -5931,16 +5946,13 @@ def rolling_max( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -6171,16 +6183,13 @@ def rolling_mean( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -6421,16 +6430,13 @@ def rolling_sum( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -6666,8 +6672,6 @@ def rolling_std( - ... - [t_n - window_size, t_n) - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. Parameters ---------- @@ -6830,7 +6834,7 @@ def rolling_std( │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ null │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 0.707107 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 0.707107 │ │ 4 ┆ 2001-01-01 04:00:00 ┆ 0.707107 │ @@ -6855,7 +6859,7 @@ def rolling_std( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.707107 │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │ @@ -6908,16 +6912,13 @@ def rolling_var( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -7080,7 +7081,7 @@ def rolling_var( │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ - │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0 │ + │ 1 ┆ 2001-01-01 01:00:00 ┆ null │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 0.5 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 0.5 │ │ 4 ┆ 2001-01-01 04:00:00 ┆ 0.5 │ @@ -7105,7 +7106,7 @@ def rolling_var( │ --- ┆ --- ┆ --- │ │ u32 ┆ datetime[μs] ┆ f64 │ ╞═══════╪═════════════════════╪═════════════════╡ - │ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0 │ + │ 0 ┆ 2001-01-01 00:00:00 ┆ null │ │ 1 ┆ 2001-01-01 01:00:00 ┆ 0.5 │ │ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │ │ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │ @@ -7165,8 +7166,6 @@ def rolling_median( - ... - [t_n - window_size, t_n) - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. Parameters ---------- @@ -7319,16 +7318,13 @@ def rolling_quantile( If `by` has not been specified (the default), the window at a given row will include the row itself, and the `window_size - 1` elements before it. - If you pass a `by` column ``, then `closed="left"` - means the windows will be: + If you pass a `by` column ``, then `closed="right"` + (the default) means the windows will be: - - [t_0 - window_size, t_0) - - [t_1 - window_size, t_1) + - (t_0 - window_size, t_0] + - (t_1 - window_size, t_1] - ... - - [t_n - window_size, t_n) - - With `closed="right"`, the left endpoint is not included and the right - endpoint is included. + - (t_n - window_size, t_n] Parameters ---------- @@ -8558,7 +8554,7 @@ def ewm_mean( *, adjust: bool = True, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Self: r""" Exponentially-weighted moving average. @@ -8587,7 +8583,7 @@ def ewm_mean( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -8601,7 +8597,7 @@ def ewm_mean( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -8609,7 +8605,7 @@ def ewm_mean( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -8619,7 +8615,7 @@ def ewm_mean( Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) - >>> df.select(pl.col("a").ewm_mean(com=1)) + >>> df.select(pl.col("a").ewm_mean(com=1, ignore_nulls=False)) shape: (3, 1) ┌──────────┐ │ a │ @@ -8631,6 +8627,16 @@ def ewm_mean( │ 2.428571 │ └──────────┘ """ + if ignore_nulls is None: + issue_deprecation_warning( + "The default value for `ignore_nulls` for `ewm` methods" + " will change from True to False in the next breaking release." + " Explicitly set `ignore_nulls=True` to keep the existing behavior" + " and silence this warning.", + version="0.20.11", + ) + ignore_nulls = True + alpha = _prepare_alpha(com, span, half_life, alpha) return self._from_pyexpr( self._pyexpr.ewm_mean(alpha, adjust, min_periods, ignore_nulls) @@ -8647,7 +8653,7 @@ def ewm_std( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Self: r""" Exponentially-weighted moving standard deviation. @@ -8676,7 +8682,7 @@ def ewm_std( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -8693,7 +8699,7 @@ def ewm_std( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -8701,7 +8707,7 @@ def ewm_std( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -8711,7 +8717,7 @@ def ewm_std( Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) - >>> df.select(pl.col("a").ewm_std(com=1)) + >>> df.select(pl.col("a").ewm_std(com=1, ignore_nulls=False)) shape: (3, 1) ┌──────────┐ │ a │ @@ -8723,6 +8729,16 @@ def ewm_std( │ 0.963624 │ └──────────┘ """ + if ignore_nulls is None: + issue_deprecation_warning( + "The default value for `ignore_nulls` for `ewm` methods" + " will change from True to False in the next breaking release." + " Explicitly set `ignore_nulls=True` to keep the existing behavior" + " and silence this warning.", + version="0.20.11", + ) + ignore_nulls = True + alpha = _prepare_alpha(com, span, half_life, alpha) return self._from_pyexpr( self._pyexpr.ewm_std(alpha, adjust, bias, min_periods, ignore_nulls) @@ -8739,7 +8755,7 @@ def ewm_var( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Self: r""" Exponentially-weighted moving variance. @@ -8768,7 +8784,7 @@ def ewm_var( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -8785,7 +8801,7 @@ def ewm_var( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -8793,7 +8809,7 @@ def ewm_var( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -8803,7 +8819,7 @@ def ewm_var( Examples -------- >>> df = pl.DataFrame({"a": [1, 2, 3]}) - >>> df.select(pl.col("a").ewm_var(com=1)) + >>> df.select(pl.col("a").ewm_var(com=1, ignore_nulls=False)) shape: (3, 1) ┌──────────┐ │ a │ @@ -8815,6 +8831,16 @@ def ewm_var( │ 0.928571 │ └──────────┘ """ + if ignore_nulls is None: + issue_deprecation_warning( + "The default value for `ignore_nulls` for `ewm` methods" + " will change from True to False in the next breaking release." + " Explicitly set `ignore_nulls=True` to keep the existing behavior" + " and silence this warning.", + version="0.20.11", + ) + ignore_nulls = True + alpha = _prepare_alpha(com, span, half_life, alpha) return self._from_pyexpr( self._pyexpr.ewm_var(alpha, adjust, bias, min_periods, ignore_nulls) @@ -9237,7 +9263,7 @@ def replace( Accepts expression input. Sequences are parsed as Series, other non-expression inputs are parsed as literals. Also accepts a mapping of values to their replacement as syntactic sugar for - `replace(new=Series(mapping.keys()), old=Series(mapping.values()))`. + `replace(old=Series(mapping.keys()), new=Series(mapping.values()))`. new Value or sequence of values to replace by. Accepts expression input. Sequences are parsed as Series, @@ -9596,6 +9622,9 @@ def shift_and_fill( """ return self.shift(n, fill_value=fill_value) + @deprecate_function( + "Use `polars.plugins.register_plugin_function` instead.", version="0.20.16" + ) def register_plugin( self, *, @@ -9609,20 +9638,26 @@ def register_plugin( cast_to_supertypes: bool = False, pass_name_to_apply: bool = False, changes_length: bool = False, - ) -> Self: + ) -> Expr: """ - Register a shared library as a plugin. + Register a plugin function. - .. warning:: - This is highly unsafe as this will call the C function - loaded by `lib::symbol`. + .. deprecated:: 0.20.16 + Use :func:`polars.plugins.register_plugin_function` instead. - The parameters you give dictate how polars will deal - with the function. Make sure they are correct! + See the `user guide `_ + for more information about plugins. - .. note:: - This functionality is unstable and may change without it - being considered breaking. + Warnings + -------- + This method is deprecated. Use the new `polars.plugins.register_plugin_function` + function instead. + + This is highly unsafe as this will call the C function loaded by + `lib::symbol`. + + The parameters you set dictate how Polars will handle the function. + Make sure they are correct! Parameters ---------- @@ -9651,31 +9686,24 @@ def register_plugin( changes_length For example a `unique` or a `slice` """ + from polars.plugins import register_plugin_function + if args is None: - args = [] + args = [self] else: - args = [parse_as_expression(a) for a in args] - if kwargs is None: - serialized_kwargs = b"" - else: - import pickle - - # Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/ - serialized_kwargs = pickle.dumps(kwargs, protocol=5) + args = [self, *list(args)] - return self._from_pyexpr( - self._pyexpr.register_plugin( - lib, - symbol, - args, - serialized_kwargs, - is_elementwise, - input_wildcard_expansion, - returns_scalar, - cast_to_supertypes, - pass_name_to_apply, - changes_length, - ) + return register_plugin_function( + plugin_path=lib, + function_name=symbol, + args=args, + kwargs=kwargs, + is_elementwise=is_elementwise, + changes_length=changes_length, + returns_scalar=returns_scalar, + cast_to_supertype=cast_to_supertypes, + input_wildcard_expansion=input_wildcard_expansion, + pass_name_to_apply=pass_name_to_apply, ) @deprecate_renamed_function("register_plugin", version="0.19.12") @@ -9690,7 +9718,7 @@ def _register_plugin( input_wildcard_expansion: bool = False, auto_explode: bool = False, cast_to_supertypes: bool = False, - ) -> Self: + ) -> Expr: return self.register_plugin( lib=lib, symbol=symbol, @@ -9846,6 +9874,29 @@ def map_dict( """ return self.replace(mapping, default=default, return_dtype=return_dtype) + @classmethod + def from_json(cls, value: str) -> Self: + """ + Read an expression from a JSON encoded string to construct an Expression. + + .. deprecated:: 0.20.11 + This method has been renamed to :meth:`deserialize`. + Note that the new method operates on file-like inputs rather than strings. + Enclose your input in `io.StringIO` to keep the same behavior. + + Parameters + ---------- + value + JSON encoded string value + """ + issue_deprecation_warning( + "`Expr.from_json` is deprecated. It has been renamed to `Expr.deserialize`." + " Note that the new method operates on file-like inputs rather than strings." + " Enclose your input in `io.StringIO` to keep the same behavior.", + version="0.20.11", + ) + return cls.deserialize(StringIO(value)) + @property def bin(self) -> ExprBinaryNameSpace: """ @@ -10036,7 +10087,7 @@ def _prepare_rolling_window_args( min_periods = window_size window_size = f"{window_size}i" elif isinstance(window_size, timedelta): - window_size = _timedelta_to_pl_duration(window_size) + window_size = parse_as_duration_string(window_size) if min_periods is None: min_periods = 1 return window_size, min_periods diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 71139e65cb668..b81caf6d6113c 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -5,12 +5,12 @@ import polars._reexport as pl from polars import functions as F -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import ( +from polars._utils.deprecation import ( deprecate_renamed_function, deprecate_renamed_parameter, ) +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr if TYPE_CHECKING: from datetime import date, datetime, time diff --git a/py-polars/polars/expr/meta.py b/py-polars/polars/expr/meta.py index 4c5e0eb2eb0ce..775405e2594cb 100644 --- a/py-polars/polars/expr/meta.py +++ b/py-polars/polars/expr/meta.py @@ -4,10 +4,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal, overload +from polars._utils.deprecation import ( + deprecate_nonkeyword_arguments, + deprecate_renamed_function, +) +from polars._utils.various import normalize_filepath +from polars._utils.wrap import wrap_expr from polars.exceptions import ComputeError -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import deprecate_nonkeyword_arguments -from polars.utils.various import normalize_filepath if TYPE_CHECKING: from io import IOBase @@ -104,12 +107,10 @@ def is_regex_projection(self) -> bool: return self._pyexpr.meta_is_regex_projection() @overload - def output_name(self, *, raise_if_undetermined: Literal[True] = True) -> str: - ... + def output_name(self, *, raise_if_undetermined: Literal[True] = True) -> str: ... @overload - def output_name(self, *, raise_if_undetermined: Literal[False]) -> str | None: - ... + def output_name(self, *, raise_if_undetermined: Literal[False]) -> str | None: ... def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: """ @@ -218,21 +219,46 @@ def _selector_and(self, other: Expr) -> Expr: return wrap_expr(self._pyexpr._meta_selector_and(other._pyexpr)) @overload - def write_json(self, file: None = ...) -> str: - ... + def serialize(self, file: None = ...) -> str: ... @overload - def write_json(self, file: IOBase | str | Path) -> None: - ... + def serialize(self, file: IOBase | str | Path) -> None: ... - def write_json(self, file: IOBase | str | Path | None = None) -> str | None: - """Write expression to json.""" + def serialize(self, file: IOBase | str | Path | None = None) -> str | None: + """ + Serialize this expression to a file or string in JSON format. + + Parameters + ---------- + file + File path to which the result should be written. If set to `None` + (default), the output is returned as a string instead. + + See Also + -------- + Expr.deserialize + + Examples + -------- + Serialize the expression into a JSON string. + + >>> expr = pl.col("foo").sum().over("bar") + >>> json = expr.meta.serialize() + >>> json + '{"Window":{"function":{"Agg":{"Sum":{"Column":"foo"}}},"partition_by":[{"Column":"bar"}],"options":{"Over":"GroupsToRows"}}}' + + The expression can later be deserialized back into an `Expr` object. + + >>> from io import StringIO + >>> pl.Expr.deserialize(StringIO(json)) # doctest: +ELLIPSIS + + """ if isinstance(file, (str, Path)): file = normalize_filepath(file) to_string_io = (file is not None) and isinstance(file, StringIO) if file is None or to_string_io: with BytesIO() as buf: - self._pyexpr.meta_write_json(buf) + self._pyexpr.serialize(buf) json_bytes = buf.getvalue() json_str = json_bytes.decode("utf8") @@ -241,16 +267,30 @@ def write_json(self, file: IOBase | str | Path | None = None) -> str | None: else: return json_str else: - self._pyexpr.meta_write_json(file) + self._pyexpr.serialize(file) return None @overload - def tree_format(self, *, return_as_string: Literal[False]) -> None: - ... + def write_json(self, file: None = ...) -> str: ... + + @overload + def write_json(self, file: IOBase | str | Path) -> None: ... + + @deprecate_renamed_function("Expr.meta.serialize", version="0.20.11") + def write_json(self, file: IOBase | str | Path | None = None) -> str | None: + """ + Write expression to json. + + .. deprecated:: 0.20.11 + This method has been renamed to :meth:`serialize`. + """ + return self.serialize(file) + + @overload + def tree_format(self, *, return_as_string: Literal[False]) -> None: ... @overload - def tree_format(self, *, return_as_string: Literal[True]) -> str: - ... + def tree_format(self, *, return_as_string: Literal[True]) -> str: ... @deprecate_nonkeyword_arguments(version="0.19.3") def tree_format(self, return_as_string: bool = False) -> str | None: # noqa: FBT001 diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index 076d883f54146..00306176c8a80 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -5,18 +5,18 @@ import polars._reexport as pl from polars import functions as F -from polars.datatypes import Date, Datetime, Int32, Time, py_type_to_dtype -from polars.datatypes.constants import N_INFER_DEFAULT -from polars.exceptions import ChronoFormatWarning -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import ( +from polars._utils.deprecation import ( deprecate_renamed_function, deprecate_renamed_parameter, issue_deprecation_warning, rename_use_earliest_to_ambiguous, ) -from polars.utils.various import find_stacklevel +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.various import find_stacklevel +from polars._utils.wrap import wrap_expr +from polars.datatypes import Date, Datetime, Int32, Time, py_type_to_dtype +from polars.datatypes.constants import N_INFER_DEFAULT +from polars.exceptions import ChronoFormatWarning if TYPE_CHECKING: from polars import Expr @@ -139,6 +139,7 @@ def to_datetime( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime + - `'null'`: set to null Examples -------- @@ -253,6 +254,7 @@ def strptime( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime + - `'null'`: set to null Notes ----- @@ -1346,8 +1348,8 @@ def json_path_match(self, json_path: str) -> Expr: return wrap_expr(self._pyexpr.str_json_path_match(json_path)) def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -1356,6 +1358,26 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Expr + Expression of data type :class:`Binary`. + + Examples + -------- + >>> df = pl.DataFrame({"color": ["000000", "ffff00", "0000ff"]}) + >>> df.with_columns(pl.col("color").str.decode("hex").alias("decoded")) + shape: (3, 2) + ┌────────┬─────────────────┐ + │ color ┆ decoded │ + │ --- ┆ --- │ + │ str ┆ binary │ + ╞════════╪═════════════════╡ + │ 000000 ┆ b"\x00\x00\x00" │ + │ ffff00 ┆ b"\xff\xff\x00" │ + │ 0000ff ┆ b"\x00\x00\xff" │ + └────────┴─────────────────┘ """ if encoding == "hex": return wrap_expr(self._pyexpr.str_hex_decode(strict)) @@ -1367,7 +1389,7 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: def encode(self, encoding: TransferEncoding) -> Expr: """ - Encode a value using the provided encoding. + Encode values using the provided encoding. Parameters ---------- diff --git a/py-polars/polars/expr/struct.py b/py-polars/polars/expr/struct.py index a8669b2f33173..709e9b8d16b19 100644 --- a/py-polars/polars/expr/struct.py +++ b/py-polars/polars/expr/struct.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Sequence -from polars.utils._wrap import wrap_expr +from polars._utils.wrap import wrap_expr if TYPE_CHECKING: from polars import Expr diff --git a/py-polars/polars/expr/whenthen.py b/py-polars/polars/expr/whenthen.py index f357289fec465..a45a093768bfc 100644 --- a/py-polars/polars/expr/whenthen.py +++ b/py-polars/polars/expr/whenthen.py @@ -3,12 +3,12 @@ from typing import TYPE_CHECKING, Any, Iterable import polars.functions as F -from polars.expr.expr import Expr -from polars.utils._parse_expr_input import ( +from polars._utils.parse_expr_input import ( parse_as_expression, parse_when_inputs, ) -from polars.utils._wrap import wrap_expr +from polars._utils.wrap import wrap_expr +from polars.expr.expr import Expr if TYPE_CHECKING: from polars.polars import PyExpr diff --git a/py-polars/polars/functions/aggregation/horizontal.py b/py-polars/polars/functions/aggregation/horizontal.py index 6d06aab8162c3..729a50d8de758 100644 --- a/py-polars/polars/functions/aggregation/horizontal.py +++ b/py-polars/polars/functions/aggregation/horizontal.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Iterable import polars.functions as F +from polars._utils.deprecation import deprecate_renamed_function +from polars._utils.parse_expr_input import parse_as_list_of_expressions +from polars._utils.wrap import wrap_expr from polars.datatypes import UInt32 -from polars.utils._parse_expr_input import parse_as_list_of_expressions -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import deprecate_renamed_function with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/functions/aggregation/vertical.py b/py-polars/polars/functions/aggregation/vertical.py index 16828027f3dd0..ecae33eac1922 100644 --- a/py-polars/polars/functions/aggregation/vertical.py +++ b/py-polars/polars/functions/aggregation/vertical.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING import polars.functions as F -from polars.utils.deprecation import deprecate_renamed_function +from polars._utils.deprecation import deprecate_renamed_function if TYPE_CHECKING: from polars import Expr diff --git a/py-polars/polars/functions/as_datatype.py b/py-polars/polars/functions/as_datatype.py index 02d2208680756..418800f478e43 100644 --- a/py-polars/polars/functions/as_datatype.py +++ b/py-polars/polars/functions/as_datatype.py @@ -4,13 +4,13 @@ from typing import TYPE_CHECKING, Iterable, overload from polars import functions as F -from polars.datatypes import Date, Struct, Time -from polars.utils._parse_expr_input import ( +from polars._utils.deprecation import rename_use_earliest_to_ambiguous +from polars._utils.parse_expr_input import ( parse_as_expression, parse_as_list_of_expressions, ) -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import rename_use_earliest_to_ambiguous +from polars._utils.wrap import wrap_expr +from polars.datatypes import Date, Struct, Time with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -75,7 +75,7 @@ def datetime_( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime - + - `'null'`: set to null Returns ------- @@ -363,8 +363,7 @@ def struct( schema: SchemaDict | None = ..., eager: Literal[False] = ..., **named_exprs: IntoExpr, -) -> Expr: - ... +) -> Expr: ... @overload @@ -373,8 +372,7 @@ def struct( schema: SchemaDict | None = ..., eager: Literal[True], **named_exprs: IntoExpr, -) -> Series: - ... +) -> Series: ... @overload @@ -383,8 +381,7 @@ def struct( schema: SchemaDict | None = ..., eager: bool, **named_exprs: IntoExpr, -) -> Expr | Series: - ... +) -> Expr | Series: ... def struct( diff --git a/py-polars/polars/functions/col.py b/py-polars/polars/functions/col.py index da5debc32de4a..90e0fd843ec9f 100644 --- a/py-polars/polars/functions/col.py +++ b/py-polars/polars/functions/col.py @@ -3,8 +3,8 @@ import contextlib from typing import TYPE_CHECKING, Any, Iterable, Protocol, cast +from polars._utils.wrap import wrap_expr from polars.datatypes import is_polars_dtype -from polars.utils._wrap import wrap_expr plr: Any = None with contextlib.suppress(ImportError): # Module not available when building docs @@ -73,11 +73,9 @@ def __call__( self, name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType], *more_names: str | PolarsDataType, - ) -> Expr: - ... + ) -> Expr: ... - def __getattr__(self, name: str) -> Expr: - ... + def __getattr__(self, name: str) -> Expr: ... # handle attribute lookup on the metaclass (we use the factory uninstantiated) diff --git a/py-polars/polars/functions/eager.py b/py-polars/polars/functions/eager.py index 573fac1912c10..dc47f8323dbd7 100644 --- a/py-polars/polars/functions/eager.py +++ b/py-polars/polars/functions/eager.py @@ -7,10 +7,10 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.various import ordered_unique +from polars._utils.wrap import wrap_df, wrap_expr, wrap_ldf, wrap_s from polars.exceptions import InvalidOperationError from polars.type_aliases import ConcatMethod, FrameType -from polars.utils._wrap import wrap_df, wrap_expr, wrap_ldf, wrap_s -from polars.utils.various import ordered_unique with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index b5e3d5c162de0..20cb0c62d080f 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -5,19 +5,19 @@ import polars._reexport as pl import polars.functions as F -from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Datetime, Int64, UInt32 -from polars.utils._async import _AioDataFrameResult, _GeventDataFrameResult -from polars.utils._parse_expr_input import ( - parse_as_expression, - parse_as_list_of_expressions, -) -from polars.utils._wrap import wrap_df, wrap_expr -from polars.utils.deprecation import ( +from polars._utils.async_ import _AioDataFrameResult, _GeventDataFrameResult +from polars._utils.deprecation import ( deprecate_parameter_as_positional, deprecate_renamed_function, issue_deprecation_warning, ) -from polars.utils.unstable import issue_unstable_warning, unstable +from polars._utils.parse_expr_input import ( + parse_as_expression, + parse_as_list_of_expressions, +) +from polars._utils.unstable import issue_unstable_warning, unstable +from polars._utils.wrap import wrap_df, wrap_expr +from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Datetime, Int64, UInt32 with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -97,8 +97,8 @@ def count(*columns: str) -> Expr: This function is syntactic sugar for `col(columns).count()`. Calling this function without any arguments returns the number of rows in the - context. **This way of using the function is deprecated. Please use :func:`len` - instead.** + context. **This way of using the function is deprecated.** Please use :func:`len` + instead. Parameters ---------- @@ -146,7 +146,7 @@ def count(*columns: str) -> Expr: └─────┴─────┘ Return the number of rows in a context. **This way of using the function is - deprecated. Please use :func:`len` instead.** + deprecated.** Please use :func:`len` instead. >>> df.select(pl.count()) # doctest: +SKIP shape: (1, 1) @@ -1706,8 +1706,7 @@ def collect_all_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = True, -) -> _GeventDataFrameResult[list[DataFrame]]: - ... +) -> _GeventDataFrameResult[list[DataFrame]]: ... @overload @@ -1724,8 +1723,7 @@ def collect_all_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = False, -) -> Awaitable[list[DataFrame]]: - ... +) -> Awaitable[list[DataFrame]]: ... @unstable() @@ -1881,18 +1879,15 @@ def select(*exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr) -> Da @overload -def arg_where(condition: Expr | Series, *, eager: Literal[False] = ...) -> Expr: - ... +def arg_where(condition: Expr | Series, *, eager: Literal[False] = ...) -> Expr: ... @overload -def arg_where(condition: Expr | Series, *, eager: Literal[True]) -> Series: - ... +def arg_where(condition: Expr | Series, *, eager: Literal[True]) -> Series: ... @overload -def arg_where(condition: Expr | Series, *, eager: bool) -> Expr | Series: - ... +def arg_where(condition: Expr | Series, *, eager: bool) -> Expr | Series: ... def arg_where(condition: Expr | Series, *, eager: bool = False) -> Expr | Series: @@ -1990,15 +1985,13 @@ def coalesce(exprs: IntoExpr | Iterable[IntoExpr], *more_exprs: IntoExpr) -> Exp @overload -def from_epoch(column: str | Expr, time_unit: EpochTimeUnit = ...) -> Expr: - ... +def from_epoch(column: str | Expr, time_unit: EpochTimeUnit = ...) -> Expr: ... @overload def from_epoch( column: Series | Sequence[int], time_unit: EpochTimeUnit = ... -) -> Series: - ... +) -> Series: ... def from_epoch( @@ -2159,8 +2152,7 @@ def sql_expr(sql: str) -> Expr: # type: ignore[overload-overlap] @overload -def sql_expr(sql: Sequence[str]) -> list[Expr]: - ... +def sql_expr(sql: Sequence[str]) -> list[Expr]: ... def sql_expr(sql: str | Sequence[str]) -> Expr | list[Expr]: diff --git a/py-polars/polars/functions/len.py b/py-polars/polars/functions/len.py index f34a3e84cbe2c..44a8ab642a8bf 100644 --- a/py-polars/polars/functions/len.py +++ b/py-polars/polars/functions/len.py @@ -3,12 +3,13 @@ Keep this function in its own module to avoid conflicts with Python's built-in `len`. """ + from __future__ import annotations import contextlib from typing import TYPE_CHECKING -from polars.utils._wrap import wrap_expr +from polars._utils.wrap import wrap_expr with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index 83075f5f317e3..b636d5f2544a6 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -5,15 +5,16 @@ from typing import TYPE_CHECKING, Any import polars._reexport as pl +from polars._utils.convert import ( + date_to_int, + datetime_to_int, + time_to_int, + timedelta_to_int, +) +from polars._utils.wrap import wrap_expr from polars.datatypes import Date, Datetime, Duration, Time from polars.dependencies import _check_for_numpy from polars.dependencies import numpy as np -from polars.utils._wrap import wrap_expr -from polars.utils.convert import ( - _datetime_to_pl_timestamp, - _time_to_pl_time, - _timedelta_to_pl_timedelta, -) with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -35,7 +36,8 @@ def lit( value Value that should be used as a `literal`. dtype - Optionally define a dtype. + The data type of the resulting expression. + If set to `None` (default), the data type is inferred from the `value` input. allow_object If type is unknown use an 'object' type. By default, we will raise a `ValueException` @@ -43,7 +45,7 @@ def lit( Notes ----- - Expected datatypes + Expected datatypes: - `pl.lit([])` -> empty Series Float32 - `pl.lit([1, 2, 3])` -> Series Int64 @@ -75,41 +77,44 @@ def lit( time_unit: TimeUnit if isinstance(value, datetime): - time_unit = "us" if dtype is None else getattr(dtype, "time_unit", "us") - time_zone = ( - value.tzinfo - if getattr(dtype, "time_zone", None) is None - else getattr(dtype, "time_zone", None) - ) - if ( - value.tzinfo is not None - and getattr(dtype, "time_zone", None) is not None - and dtype.time_zone != str(value.tzinfo) # type: ignore[union-attr] - ): - msg = f"time zone of dtype ({dtype.time_zone!r}) differs from time zone of value ({value.tzinfo!r})" # type: ignore[union-attr] - raise TypeError(msg) - e = lit( - _datetime_to_pl_timestamp(value.replace(tzinfo=timezone.utc), time_unit) - ).cast(Datetime(time_unit)) + if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: + time_unit = tu # type: ignore[assignment] + else: + time_unit = "us" + + time_zone: str | None = getattr(dtype, "time_zone", None) + if (tzinfo := value.tzinfo) is not None: + tzinfo_str = str(tzinfo) + if time_zone is not None and time_zone != tzinfo_str: + msg = f"time zone of dtype ({time_zone!r}) differs from time zone of value ({tzinfo!r})" + raise TypeError(msg) + time_zone = tzinfo_str + + dt_utc = value.replace(tzinfo=timezone.utc) + dt_int = datetime_to_int(dt_utc, time_unit) + expr = lit(dt_int).cast(Datetime(time_unit)) if time_zone is not None: - return e.dt.replace_time_zone( - str(time_zone), ambiguous="earliest" if value.fold == 0 else "latest" + expr = expr.dt.replace_time_zone( + time_zone, ambiguous="earliest" if value.fold == 0 else "latest" ) - else: - return e + return expr elif isinstance(value, timedelta): - if dtype is None or (time_unit := getattr(dtype, "time_unit", "us")) is None: + if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: + time_unit = tu # type: ignore[assignment] + else: time_unit = "us" - return lit(_timedelta_to_pl_timedelta(value, time_unit)).cast( - Duration(time_unit) - ) + + td_int = timedelta_to_int(value, time_unit) + return lit(td_int).cast(Duration(time_unit)) elif isinstance(value, time): - return lit(_time_to_pl_time(value)).cast(Time) + time_int = time_to_int(value) + return lit(time_int).cast(Time) elif isinstance(value, date): - return lit(datetime(value.year, value.month, value.day)).cast(Date) + date_int = date_to_int(value) + return lit(date_int).cast(Date) elif isinstance(value, pl.Series): value = value._s diff --git a/py-polars/polars/functions/range/_utils.py b/py-polars/polars/functions/range/_utils.py index da1b38dbbd1c0..86bdeedd15cd6 100644 --- a/py-polars/polars/functions/range/_utils.py +++ b/py-polars/polars/functions/range/_utils.py @@ -2,13 +2,13 @@ from datetime import timedelta -from polars.utils.convert import _timedelta_to_pl_duration +from polars._utils.convert import parse_as_duration_string def parse_interval_argument(interval: str | timedelta) -> str: """Parse the interval argument as a Polars duration string.""" if isinstance(interval, timedelta): - return _timedelta_to_pl_duration(interval) + return parse_as_duration_string(interval) if " " in interval: interval = interval.replace(" ", "") diff --git a/py-polars/polars/functions/range/date_range.py b/py-polars/polars/functions/range/date_range.py index ae6e0af5dea1f..6987d92459091 100644 --- a/py-polars/polars/functions/range/date_range.py +++ b/py-polars/polars/functions/range/date_range.py @@ -5,13 +5,13 @@ from typing import TYPE_CHECKING, overload from polars import functions as F -from polars.functions.range._utils import parse_interval_argument -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import ( +from polars._utils.deprecation import ( deprecate_saturating, issue_deprecation_warning, ) +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr +from polars.functions.range._utils import parse_interval_argument with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -34,8 +34,7 @@ def date_range( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -48,8 +47,7 @@ def date_range( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -62,8 +60,7 @@ def date_range( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: bool, -) -> Series | Expr: - ... +) -> Series | Expr: ... def date_range( @@ -201,8 +198,7 @@ def date_ranges( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -215,8 +211,7 @@ def date_ranges( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -229,8 +224,7 @@ def date_ranges( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: bool, -) -> Series | Expr: - ... +) -> Series | Expr: ... def date_ranges( diff --git a/py-polars/polars/functions/range/datetime_range.py b/py-polars/polars/functions/range/datetime_range.py index 76321e59b0a58..9eadbc41c6ccc 100644 --- a/py-polars/polars/functions/range/datetime_range.py +++ b/py-polars/polars/functions/range/datetime_range.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, overload from polars import functions as F +from polars._utils.deprecation import deprecate_saturating +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr from polars.functions.range._utils import parse_interval_argument -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import deprecate_saturating with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -30,8 +30,7 @@ def datetime_range( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -44,8 +43,7 @@ def datetime_range( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -58,8 +56,7 @@ def datetime_range( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: bool, -) -> Series | Expr: - ... +) -> Series | Expr: ... def datetime_range( @@ -206,8 +203,7 @@ def datetime_ranges( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -220,8 +216,7 @@ def datetime_ranges( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -234,8 +229,7 @@ def datetime_ranges( time_unit: TimeUnit | None = ..., time_zone: str | None = ..., eager: bool, -) -> Series | Expr: - ... +) -> Series | Expr: ... def datetime_ranges( diff --git a/py-polars/polars/functions/range/int_range.py b/py-polars/polars/functions/range/int_range.py index c23e0196a8e25..35cb691f5aee5 100644 --- a/py-polars/polars/functions/range/int_range.py +++ b/py-polars/polars/functions/range/int_range.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, overload from polars import functions as F +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr, wrap_s from polars.datatypes import Int64 -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr, wrap_s with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -26,8 +26,7 @@ def arange( *, dtype: PolarsIntegerType = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -38,8 +37,7 @@ def arange( *, dtype: PolarsIntegerType = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -50,8 +48,7 @@ def arange( *, dtype: PolarsIntegerType = ..., eager: bool, -) -> Expr | Series: - ... +) -> Expr | Series: ... def arange( @@ -113,8 +110,7 @@ def int_range( *, dtype: PolarsIntegerType = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -125,8 +121,7 @@ def int_range( *, dtype: PolarsIntegerType = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -137,8 +132,7 @@ def int_range( *, dtype: PolarsIntegerType = ..., eager: bool, -) -> Expr | Series: - ... +) -> Expr | Series: ... def int_range( @@ -241,8 +235,7 @@ def int_ranges( *, dtype: PolarsIntegerType = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -253,8 +246,7 @@ def int_ranges( *, dtype: PolarsIntegerType = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -265,8 +257,7 @@ def int_ranges( *, dtype: PolarsIntegerType = ..., eager: bool, -) -> Expr | Series: - ... +) -> Expr | Series: ... def int_ranges( diff --git a/py-polars/polars/functions/range/time_range.py b/py-polars/polars/functions/range/time_range.py index 5563e072b357e..ed00289f90265 100644 --- a/py-polars/polars/functions/range/time_range.py +++ b/py-polars/polars/functions/range/time_range.py @@ -5,10 +5,10 @@ from typing import TYPE_CHECKING, overload from polars import functions as F +from polars._utils.deprecation import deprecate_saturating +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr from polars.functions.range._utils import parse_interval_argument -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr -from polars.utils.deprecation import deprecate_saturating with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -29,8 +29,7 @@ def time_range( *, closed: ClosedInterval = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -41,8 +40,7 @@ def time_range( *, closed: ClosedInterval = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -53,8 +51,7 @@ def time_range( *, closed: ClosedInterval = ..., eager: bool, -) -> Series | Expr: - ... +) -> Series | Expr: ... def time_range( @@ -166,8 +163,7 @@ def time_ranges( *, closed: ClosedInterval = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -178,8 +174,7 @@ def time_ranges( *, closed: ClosedInterval = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -190,8 +185,7 @@ def time_ranges( *, closed: ClosedInterval = ..., eager: bool, -) -> Series | Expr: - ... +) -> Series | Expr: ... def time_ranges( diff --git a/py-polars/polars/functions/repeat.py b/py-polars/polars/functions/repeat.py index c922a5a055031..5f76a1a043b26 100644 --- a/py-polars/polars/functions/repeat.py +++ b/py-polars/polars/functions/repeat.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any, overload from polars import functions as F +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr from polars.datatypes import ( FLOAT_DTYPES, INTEGER_DTYPES, @@ -16,8 +18,6 @@ List, Utf8, ) -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr @@ -57,8 +57,7 @@ def repeat( *, dtype: PolarsDataType | None = ..., eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -68,8 +67,7 @@ def repeat( *, dtype: PolarsDataType | None = ..., eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -79,8 +77,7 @@ def repeat( *, dtype: PolarsDataType | None = ..., eager: bool, -) -> Expr | Series: - ... +) -> Expr | Series: ... def repeat( @@ -155,8 +152,7 @@ def ones( dtype: PolarsDataType = ..., *, eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -165,8 +161,7 @@ def ones( dtype: PolarsDataType = ..., *, eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -175,8 +170,7 @@ def ones( dtype: PolarsDataType = ..., *, eager: bool, -) -> Expr | Series: - ... +) -> Expr | Series: ... def ones( @@ -234,8 +228,7 @@ def zeros( dtype: PolarsDataType = ..., *, eager: Literal[False] = ..., -) -> Expr: - ... +) -> Expr: ... @overload @@ -244,8 +237,7 @@ def zeros( dtype: PolarsDataType = ..., *, eager: Literal[True], -) -> Series: - ... +) -> Series: ... @overload @@ -254,8 +246,7 @@ def zeros( dtype: PolarsDataType = ..., *, eager: bool, -) -> Expr | Series: - ... +) -> Expr | Series: ... def zeros( diff --git a/py-polars/polars/functions/whenthen.py b/py-polars/polars/functions/whenthen.py index 77ab37a09c024..2c6254a346d35 100644 --- a/py-polars/polars/functions/whenthen.py +++ b/py-polars/polars/functions/whenthen.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Iterable import polars._reexport as pl -from polars.utils._parse_expr_input import parse_when_inputs +from polars._utils.parse_expr_input import parse_when_inputs with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/io/__init__.py b/py-polars/polars/io/__init__.py index f4a39b5f778ab..395f15bd4c94c 100644 --- a/py-polars/polars/io/__init__.py +++ b/py-polars/polars/io/__init__.py @@ -21,8 +21,8 @@ "read_delta", "read_excel", "read_ipc", - "read_ipc_stream", "read_ipc_schema", + "read_ipc_stream", "read_json", "read_ndjson", "read_ods", diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index 1c28c7b045ce7..dac87c04900d5 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -8,9 +8,9 @@ from tempfile import NamedTemporaryFile from typing import IO, Any, ContextManager, Iterator, cast, overload +from polars._utils.various import normalize_filepath from polars.dependencies import _FSSPEC_AVAILABLE, fsspec from polars.exceptions import NoDataError -from polars.utils.various import normalize_filepath def _is_glob_pattern(file: str) -> bool: @@ -32,14 +32,13 @@ def _is_local_file(file: str) -> bool: @overload def _prepare_file_arg( - file: str | list[str] | Path | IO[bytes] | bytes, + file: str | Path | list[str] | IO[bytes] | bytes, encoding: str | None = ..., *, use_pyarrow: bool = ..., raise_if_empty: bool = ..., storage_options: dict[str, Any] | None = ..., -) -> ContextManager[str | BytesIO]: - ... +) -> ContextManager[str | BytesIO]: ... @overload @@ -50,24 +49,22 @@ def _prepare_file_arg( use_pyarrow: bool = ..., raise_if_empty: bool = ..., storage_options: dict[str, Any] | None = ..., -) -> ContextManager[str | BytesIO]: - ... +) -> ContextManager[str | BytesIO]: ... @overload def _prepare_file_arg( - file: str | list[str] | Path | IO[str] | IO[bytes] | bytes, + file: str | Path | list[str] | IO[str] | IO[bytes] | bytes, encoding: str | None = ..., *, use_pyarrow: bool = ..., raise_if_empty: bool = ..., storage_options: dict[str, Any] | None = ..., -) -> ContextManager[str | list[str] | BytesIO | list[BytesIO]]: - ... +) -> ContextManager[str | list[str] | BytesIO | list[BytesIO]]: ... def _prepare_file_arg( - file: str | list[str] | Path | IO[str] | IO[bytes] | bytes, + file: str | Path | list[str] | IO[str] | IO[bytes] | bytes, encoding: str | None = None, *, use_pyarrow: bool = False, diff --git a/py-polars/polars/io/avro.py b/py-polars/polars/io/avro.py index e93667ee00aec..a25be704c4f2d 100644 --- a/py-polars/polars/io/avro.py +++ b/py-polars/polars/io/avro.py @@ -1,18 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, BinaryIO +from typing import IO, TYPE_CHECKING import polars._reexport as pl if TYPE_CHECKING: - from io import BytesIO from pathlib import Path from polars import DataFrame def read_avro( - source: str | Path | BytesIO | BinaryIO, + source: str | Path | IO[bytes] | bytes, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, diff --git a/py-polars/polars/io/csv/__init__.py b/py-polars/polars/io/csv/__init__.py index b18232f103468..cf5a2646240d7 100644 --- a/py-polars/polars/io/csv/__init__.py +++ b/py-polars/polars/io/csv/__init__.py @@ -1,6 +1,8 @@ +from polars.io.csv.batched_reader import BatchedCsvReader from polars.io.csv.functions import read_csv, read_csv_batched, scan_csv __all__ = [ + "BatchedCsvReader", "read_csv", "read_csv_batched", "scan_csv", diff --git a/py-polars/polars/io/csv/batched_reader.py b/py-polars/polars/io/csv/batched_reader.py index 84ad7ba57b094..101672f7a5e93 100644 --- a/py-polars/polars/io/csv/batched_reader.py +++ b/py-polars/polars/io/csv/batched_reader.py @@ -1,23 +1,24 @@ from __future__ import annotations import contextlib -from pathlib import Path from typing import TYPE_CHECKING, Sequence -from polars.datatypes import N_INFER_DEFAULT, py_type_to_dtype -from polars.io.csv._utils import _update_columns -from polars.utils._wrap import wrap_df -from polars.utils.various import ( +from polars._utils.various import ( _prepare_row_index_args, _process_null_values, handle_projection_columns, normalize_filepath, ) +from polars._utils.wrap import wrap_df +from polars.datatypes import N_INFER_DEFAULT, py_type_to_dtype +from polars.io.csv._utils import _update_columns with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyBatchedCsv if TYPE_CHECKING: + from pathlib import Path + from polars import DataFrame from polars.type_aliases import CsvEncoding, PolarsDataType, SchemaDict @@ -56,9 +57,7 @@ def __init__( raise_if_empty: bool = True, truncate_ragged_lines: bool = False, ): - path: str | None - if isinstance(source, (str, Path)): - path = normalize_filepath(source) + path = normalize_filepath(source) dtype_list: Sequence[tuple[str, PolarsDataType]] | None = None dtype_slice: Sequence[PolarsDataType] | None = None @@ -111,14 +110,12 @@ def next_batches(self, n: int) -> list[DataFrame] | None: """ Read `n` batches from the reader. - The `n` chunks will be parallelized over the - available threads. + These batches will be parallelized over the available threads. Parameters ---------- n - Number of chunks to fetch. - This is ideally >= number of threads + Number of chunks to fetch; ideally this is >= number of threads. Examples -------- diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index b82400caf25c6..9291345a9a600 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -1,19 +1,17 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Mapping, Sequence, TextIO +from typing import IO, TYPE_CHECKING, Any, Callable, Mapping, Sequence import polars._reexport as pl +from polars._utils.deprecation import deprecate_renamed_parameter +from polars._utils.various import handle_projection_columns, normalize_filepath from polars.datatypes import N_INFER_DEFAULT, String from polars.io._utils import _prepare_file_arg from polars.io.csv._utils import _check_arg_is_1byte, _update_columns from polars.io.csv.batched_reader import BatchedCsvReader -from polars.utils.deprecation import deprecate_renamed_parameter -from polars.utils.various import handle_projection_columns, normalize_filepath if TYPE_CHECKING: - from io import BytesIO - from polars import DataFrame, LazyFrame from polars.type_aliases import CsvEncoding, PolarsDataType, SchemaDict @@ -24,7 +22,7 @@ old_name="comment_char", new_name="comment_prefix", version="0.19.14" ) def read_csv( - source: str | TextIO | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[str] | IO[bytes] | bytes, *, has_header: bool = True, columns: Sequence[int] | Sequence[str] | None = None, @@ -101,6 +99,7 @@ def read_csv( - `List[str]`: All values equal to any string in this list will be null. - `Dict[str, str]`: A dictionary that maps column name to a null value string. + missing_utf8_is_empty_string By default a missing value is considered to be null; if you would prefer missing utf8 values to be treated as the empty string you can set this param True. @@ -509,6 +508,7 @@ def read_csv_batched( - `List[str]`: All values equal to any string in this list will be null. - `Dict[str, str]`: A dictionary that maps column name to a null value string. + missing_utf8_is_empty_string By default a missing value is considered to be null; if you would prefer missing utf8 values to be treated as the empty string you can set this param True. @@ -802,6 +802,7 @@ def scan_csv( - `List[str]`: All values equal to any string in this list will be null. - `Dict[str, str]`: A dictionary that maps column name to a null value string. + missing_utf8_is_empty_string By default a missing value is considered to be null; if you would prefer missing utf8 values to be treated as the empty string you can set this param True. diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 17e020cf7a3e6..f7e933b1e448e 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -6,13 +6,16 @@ from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Iterable, Literal, Sequence, TypedDict, overload +from polars._utils.deprecation import issue_deprecation_warning from polars.convert import from_arrow +from polars.datatypes import N_INFER_DEFAULT from polars.exceptions import InvalidOperationError, UnsuitableSQLError -from polars.utils.deprecation import issue_deprecation_warning if TYPE_CHECKING: from types import TracebackType + import pyarrow as pa + if sys.version_info >= (3, 10): from typing import TypeAlias else: @@ -23,7 +26,6 @@ from typing_extensions import Self from polars import DataFrame - from polars.dependencies import pyarrow as pa from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict try: @@ -33,10 +35,15 @@ class _ArrowDriverProperties_(TypedDict): - fetch_all: str # name of the method that fetches all arrow data - fetch_batches: str | None # name of the method that fetches arrow data in batches - exact_batch_size: bool | None # whether indicated batch size is respected exactly - repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator) + # name of the method that fetches all arrow data; tuple form + # calls the fetch_all method with the given chunk size (int) + fetch_all: str | tuple[str, int] + # name of the method that fetches arrow data in batches + fetch_batches: str | None + # indicate whether the given batch size is respected exactly + exact_batch_size: bool | None + # repeat batch calls (if False, the batch call is a generator) + repeat_batch_calls: bool _ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = { @@ -47,7 +54,7 @@ class _ArrowDriverProperties_(TypedDict): "repeat_batch_calls": False, }, "arrow_odbc_proxy": { - "fetch_all": "fetch_record_batches", + "fetch_all": "fetch_arrow_table", "fetch_batches": "fetch_record_batches", "exact_batch_size": True, "repeat_batch_calls": False, @@ -64,6 +71,13 @@ class _ArrowDriverProperties_(TypedDict): "exact_batch_size": True, "repeat_batch_calls": False, }, + "kuzu": { + # 'get_as_arrow' currently takes a mandatory chunk size + "fetch_all": ("get_as_arrow", 10_000), + "fetch_batches": None, + "exact_batch_size": None, + "repeat_batch_calls": False, + }, "snowflake": { "fetch_all": "fetch_arrow_all", "fetch_batches": "fetch_arrow_batches", @@ -109,21 +123,41 @@ def execute(self, query: str, **execute_options: Any) -> None: self.execute_options = execute_options self.query = query + def fetch_arrow_table( + self, batch_size: int = 10_000, *, fetch_all: bool = False + ) -> pa.Table: + """Fetch all results as a pyarrow Table.""" + from pyarrow import Table + + return Table.from_batches( + self.fetch_record_batches(batch_size=batch_size, fetch_all=True) + ) + def fetch_record_batches( - self, batch_size: int = 10_000 + self, batch_size: int = 10_000, *, fetch_all: bool = False ) -> Iterable[pa.RecordBatch]: - """Fetch results in batches.""" + """Fetch results as an iterable of RecordBatches.""" from arrow_odbc import read_arrow_batches_from_odbc + from pyarrow import RecordBatch - yield from read_arrow_batches_from_odbc( + n_batches = 0 + batch_reader = read_arrow_batches_from_odbc( query=self.query, batch_size=batch_size, connection_string=self.connection_string, **self.execute_options, ) + for batch in batch_reader: + yield batch + n_batches += 1 + + if n_batches == 0 and fetch_all: + # empty result set; return empty batch with accurate schema + yield RecordBatch.from_pylist([], schema=batch_reader.schema) - # internally arrow-odbc always reads batches - fetchall = fetchmany = fetch_record_batches + # note: internally arrow-odbc always reads batches + fetchall = fetch_arrow_table + fetchmany = fetch_record_batches class ConnectionExecutor: @@ -153,24 +187,27 @@ def __exit__( ) -> None: # if we created it and are finished with it, we can # close the cursor (but NOT the connection) - if self.can_close_cursor: + if self.can_close_cursor and hasattr(self.cursor, "close"): self.cursor.close() def __repr__(self) -> str: return f"<{type(self).__name__} module={self.driver_name!r}>" - def _arrow_batches( + def _fetch_arrow( self, driver_properties: _ArrowDriverProperties_, *, batch_size: int | None, iter_batches: bool, ) -> Iterable[pa.RecordBatch]: - """Yield Arrow data in batches, or as a single 'fetchall' batch.""" + """Yield Arrow data as a generator of one or more RecordBatches or Tables.""" fetch_batches = driver_properties["fetch_batches"] if not iter_batches or fetch_batches is None: - fetch_method = driver_properties["fetch_all"] - yield getattr(self.result, fetch_method)() + fetch_method, sz = driver_properties["fetch_all"], [] + if isinstance(fetch_method, tuple): + fetch_method, chunk_size = fetch_method + sz = [chunk_size] + yield getattr(self.result, fetch_method)(*sz) else: size = batch_size if driver_properties["exact_batch_size"] else None repeat_batch_calls = driver_properties["repeat_batch_calls"] @@ -184,31 +221,6 @@ def _arrow_batches( break yield arrow - def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor: - """Normalise a connection object such that we have the query executor.""" - if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine": - self.can_close_cursor = True - if conn.driver == "databricks-sql-python": # type: ignore[union-attr] - # take advantage of the raw connection to get arrow integration - self.driver_name = "databricks" - return conn.raw_connection().cursor() # type: ignore[union-attr, return-value] - else: - # sqlalchemy engine; direct use is deprecated, so prefer the connection - return conn.connect() # type: ignore[union-attr, return-value] - - elif hasattr(conn, "cursor"): - # connection has a dedicated cursor; prefer over direct execute - cursor = cursor() if callable(cursor := conn.cursor) else cursor - self.can_close_cursor = True - return cursor - - elif hasattr(conn, "execute"): - # can execute directly (given cursor, sqlalchemy connection, etc) - return conn # type: ignore[return-value] - - msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method" - raise TypeError(msg) - @staticmethod def _fetchall_rows(result: Cursor) -> Iterable[Sequence[Any]]: """Fetch row data in a single call, returning the complete result set.""" @@ -238,6 +250,7 @@ def _from_arrow( batch_size: int | None, iter_batches: bool, schema_overrides: SchemaDict | None, + infer_schema_length: int | None, ) -> DataFrame | Iterable[DataFrame] | None: """Return resultset data in Arrow format for frame init.""" from polars import from_arrow @@ -249,7 +262,7 @@ def _from_arrow( self.can_close_cursor = fetch_batches is None or not iter_batches frames = ( from_arrow(batch, schema_overrides=schema_overrides) - for batch in self._arrow_batches( + for batch in self._fetch_arrow( driver_properties, iter_batches=iter_batches, batch_size=batch_size, @@ -274,26 +287,36 @@ def _from_rows( batch_size: int | None, iter_batches: bool, schema_overrides: SchemaDict | None, + infer_schema_length: int | None, ) -> DataFrame | Iterable[DataFrame] | None: """Return resultset data row-wise for frame init.""" from polars import DataFrame if hasattr(self.result, "fetchall"): - description = ( - self.result.cursor.description - if self.driver_name == "sqlalchemy" - else self.result.description - ) - column_names = [desc[0] for desc in description] + if self.driver_name == "sqlalchemy": + if hasattr(self.result, "cursor"): + cursor_desc = {d[0]: d[1] for d in self.result.cursor.description} + elif hasattr(self.result, "_metadata"): + cursor_desc = {k: None for k in self.result._metadata.keys} + else: + msg = f"Unable to determine metadata from query result; {self.result!r}" + raise ValueError(msg) + else: + cursor_desc = {d[0]: d[1] for d in self.result.description} + + # TODO: refine types based on the cursor description's type_code, + # if/where available? (for now, we just read the column names) + result_columns = list(cursor_desc) frames = ( DataFrame( data=rows, - schema=column_names, + schema=result_columns, schema_overrides=schema_overrides, + infer_schema_length=infer_schema_length, orient="row", ) for rows in ( - self._fetchmany_rows(self.result, batch_size) + list(self._fetchmany_rows(self.result, batch_size)) if iter_batches else [self._fetchall_rows(self.result)] # type: ignore[list-item] ) @@ -301,6 +324,31 @@ def _from_rows( return frames if iter_batches else next(frames) # type: ignore[arg-type] return None + def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor: + """Normalise a connection object such that we have the query executor.""" + if self.driver_name == "sqlalchemy" and type(conn).__name__ == "Engine": + self.can_close_cursor = True + if conn.driver == "databricks-sql-python": # type: ignore[union-attr] + # take advantage of the raw connection to get arrow integration + self.driver_name = "databricks" + return conn.raw_connection().cursor() # type: ignore[union-attr, return-value] + else: + # sqlalchemy engine; direct use is deprecated, so prefer the connection + return conn.connect() # type: ignore[union-attr, return-value] + + elif hasattr(conn, "cursor"): + # connection has a dedicated cursor; prefer over direct execute + cursor = cursor() if callable(cursor := conn.cursor) else cursor + self.can_close_cursor = True + return cursor + + elif hasattr(conn, "execute"): + # can execute directly (given cursor, sqlalchemy connection, etc) + return conn # type: ignore[return-value] + + msg = f"Unrecognised connection {conn!r}; unable to find 'execute' method" + raise TypeError(msg) + def execute( self, query: str | Selectable, @@ -318,18 +366,33 @@ def execute( options = options or {} cursor_execute = self.cursor.execute - if self.driver_name == "sqlalchemy" and isinstance(query, str): - params = options.get("parameters") - if isinstance(params, Sequence) and hasattr(self.cursor, "exec_driver_sql"): - cursor_execute = self.cursor.exec_driver_sql - if isinstance(params, list) and not all( - isinstance(p, (dict, tuple)) for p in params + if self.driver_name == "sqlalchemy": + from sqlalchemy.orm import Session + + param_key = "parameters" + if ( + isinstance(self.cursor, Session) + and "parameters" in options + and "params" not in options + ): + options = options.copy() + options["params"] = options.pop("parameters") + param_key = "params" + + if isinstance(query, str): + params = options.get(param_key) + if isinstance(params, Sequence) and hasattr( + self.cursor, "exec_driver_sql" ): - options["parameters"] = tuple(params) - else: - from sqlalchemy.sql import text + cursor_execute = self.cursor.exec_driver_sql + if isinstance(params, list) and not all( + isinstance(p, (dict, tuple)) for p in params + ): + options[param_key] = tuple(params) + else: + from sqlalchemy.sql import text - query = text(query) # type: ignore[assignment] + query = text(query) # type: ignore[assignment] # note: some cursor execute methods (eg: sqlite3) only take positional # params, hence the slightly convoluted resolution of the 'options' dict @@ -360,6 +423,7 @@ def to_polars( iter_batches: bool = False, batch_size: int | None = None, schema_overrides: SchemaDict | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, ) -> DataFrame | Iterable[DataFrame]: """ Convert the result set to a DataFrame. @@ -384,6 +448,7 @@ def to_polars( batch_size=batch_size, iter_batches=iter_batches, schema_overrides=schema_overrides, + infer_schema_length=infer_schema_length, ) if frame is not None: return frame @@ -402,9 +467,10 @@ def read_database( iter_batches: Literal[False] = False, batch_size: int | None = ..., schema_overrides: SchemaDict | None = ..., + infer_schema_length: int | None = ..., + execute_options: dict[str, Any] | None = ..., **kwargs: Any, -) -> DataFrame: - ... +) -> DataFrame: ... @overload @@ -415,9 +481,10 @@ def read_database( iter_batches: Literal[True], batch_size: int | None = ..., schema_overrides: SchemaDict | None = ..., + infer_schema_length: int | None = ..., + execute_options: dict[str, Any] | None = ..., **kwargs: Any, -) -> Iterable[DataFrame]: - ... +) -> Iterable[DataFrame]: ... def read_database( # noqa: D417 @@ -427,6 +494,7 @@ def read_database( # noqa: D417 iter_batches: bool = False, batch_size: int | None = None, schema_overrides: SchemaDict | None = None, + infer_schema_length: int | None = N_INFER_DEFAULT, execute_options: dict[str, Any] | None = None, **kwargs: Any, ) -> DataFrame | Iterable[DataFrame]: @@ -440,9 +508,10 @@ def read_database( # noqa: D417 be a suitable "Selectable", otherwise it is expected to be a string). connection An instantiated connection (or cursor/client object) that the query can be - executed against. Can also pass a valid ODBC connection string, starting with - "Driver=", in which case the `arrow-odbc` package will be used to establish - the connection and return Arrow-native data to Polars. + executed against. Can also pass a valid ODBC connection string, identified as + such if it contains the string "Driver=", in which case the `arrow-odbc` + package will be used to establish the connection and return Arrow-native data + to Polars. iter_batches Return an iterator of DataFrames, where each DataFrame represents a batch of data returned by the query; this can be useful for processing large resultsets @@ -464,6 +533,11 @@ def read_database( # noqa: D417 on driver/backend). This can be useful if the given types can be more precisely defined (for example, if you know that a given column can be declared as `u32` instead of `i64`). + infer_schema_length + The maximum number of rows to scan for schema inference. If set to `None`, the + full data may be scanned *(this can be slow)*. This parameter only applies if + the data is read as a sequence of rows and the `schema_overrides` parameter + is not set for the given column; Arrow-aware drivers also ignore this value. execute_options These options will be passed through into the underlying query execution method as kwargs. In the case of connections made using an ODBC string (which use @@ -484,17 +558,20 @@ def read_database( # noqa: D417 more details about using this driver (notable databases implementing Flight SQL include Dremio and InfluxDB). - * The `read_database_uri` function is likely to be noticeably faster than - `read_database` if you are using a SQLAlchemy or DBAPI2 connection, as - `connectorx` will optimise translation of the result set into Arrow format - in Rust, whereas these libraries will return row-wise data to Python *before* - we can load into Arrow. Note that you can easily determine the connection's - URI from a SQLAlchemy engine object by calling + * The `read_database_uri` function can be noticeably faster than `read_database` + if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` optimises + translation of the result set into Arrow format in Rust, whereas these libraries + will return row-wise data to Python *before* we can load into Arrow. Note that + you can determine the connection's URI from a SQLAlchemy engine object by calling `conn.engine.url.render_as_string(hide_password=False)`. * If polars has to create a cursor from your connection in order to execute the query then that cursor will be automatically closed when the query completes; - however, polars will *never* close any other connection or cursor. + however, polars will *never* close any other open connection or cursor. + + * We are able to support more than just relational databases and SQL queries + through this function. For example, we can load graph database results from + a `KùzuDB` connection in conjunction with a Cypher query. See Also -------- @@ -539,10 +616,18 @@ def read_database( # noqa: D417 ... batch_size=1000, ... ): ... do_something(df) # doctest: +SKIP + + Load graph data query results from a `KùzuDB` connection and a Cypher query: + + >>> df = pl.read_database( + ... query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name", + ... connection=kuzu_db_conn, + ... ) # doctest: +SKIP + """ # noqa: W505 if isinstance(connection, str): # check for odbc connection string - if re.sub(r"\s", "", connection[:20]).lower().startswith("driver="): + if re.search(r"\bdriver\s*=\s*{[^}]+?}", connection, re.IGNORECASE): try: import arrow_odbc # noqa: F401 except ModuleNotFoundError: @@ -587,6 +672,7 @@ def read_database( # noqa: D417 batch_size=batch_size, iter_batches=iter_batches, schema_overrides=schema_overrides, + infer_schema_length=infer_schema_length, ) @@ -600,6 +686,7 @@ def read_database_uri( protocol: str | None = None, engine: DbReadEngine | None = None, schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, ) -> DataFrame: """ Read the results of a SQL query into a DataFrame, given a URI. @@ -646,6 +733,9 @@ def read_database_uri( schema_overrides A dictionary mapping column names to dtypes, used to override the schema given in the data returned by the query. + execute_options + These options will be passed to the underlying query execution method as + kwargs. Note that connectorx does not support this parameter. Notes ----- @@ -714,6 +804,9 @@ def read_database_uri( engine = "connectorx" if engine == "connectorx": + if execute_options: + msg = "the 'connectorx' engine does not support use of `execute_options`" + raise ValueError(msg) return _read_sql_connectorx( query, connection_uri=uri, @@ -727,7 +820,12 @@ def read_database_uri( if not isinstance(query, str): msg = "only a single SQL query string is accepted for adbc" raise ValueError(msg) - return _read_sql_adbc(query, uri, schema_overrides) + return _read_sql_adbc( + query, + connection_uri=uri, + schema_overrides=schema_overrides, + execute_options=execute_options, + ) else: msg = f"engine must be one of {{'connectorx', 'adbc'}}, got {engine!r}" raise ValueError(msg) @@ -767,10 +865,13 @@ def _read_sql_connectorx( def _read_sql_adbc( - query: str, connection_uri: str, schema_overrides: SchemaDict | None + query: str, + connection_uri: str, + schema_overrides: SchemaDict | None, + execute_options: dict[str, Any] | None = None, ) -> DataFrame: with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor: - cursor.execute(query) + cursor.execute(query, **(execute_options or {})) tbl = cursor.fetch_arrow_table() return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value] diff --git a/py-polars/polars/io/iceberg.py b/py-polars/polars/io/iceberg.py index 558604150f0cf..ec57b401c907e 100644 --- a/py-polars/polars/io/iceberg.py +++ b/py-polars/polars/io/iceberg.py @@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable import polars._reexport as pl +from polars._utils.convert import to_py_date, to_py_datetime from polars.dependencies import pyiceberg -from polars.utils.convert import _to_python_date, _to_python_datetime if TYPE_CHECKING: from datetime import date, datetime @@ -34,8 +34,8 @@ __all__ = ["scan_iceberg"] _temporal_conversions: dict[str, Callable[..., datetime | date]] = { - "_to_python_date": _to_python_date, - "_to_python_datetime": _to_python_datetime, + "to_py_date": to_py_date, + "to_py_datetime": to_py_datetime, } diff --git a/py-polars/polars/io/ipc/functions.py b/py-polars/polars/io/ipc/functions.py index 8ca0c5b3af4b9..55d07848cb5e8 100644 --- a/py-polars/polars/io/ipc/functions.py +++ b/py-polars/polars/io/ipc/functions.py @@ -2,27 +2,25 @@ import contextlib from pathlib import Path -from typing import IO, TYPE_CHECKING, Any, BinaryIO +from typing import IO, TYPE_CHECKING, Any import polars._reexport as pl +from polars._utils.deprecation import deprecate_renamed_parameter +from polars._utils.various import normalize_filepath from polars.dependencies import _PYARROW_AVAILABLE from polars.io._utils import _prepare_file_arg -from polars.utils.deprecation import deprecate_renamed_parameter -from polars.utils.various import normalize_filepath with contextlib.suppress(ImportError): from polars.polars import read_ipc_schema as _read_ipc_schema if TYPE_CHECKING: - from io import BytesIO - from polars import DataFrame, DataType, LazyFrame @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_ipc( - source: str | BinaryIO | BytesIO | Path | bytes, + source: str | Path | IO[bytes] | bytes, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -114,7 +112,7 @@ def read_ipc( @deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4") @deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4") def read_ipc_stream( - source: str | BinaryIO | BytesIO | Path | bytes, + source: str | Path | IO[bytes] | bytes, *, columns: list[int] | list[str] | None = None, n_rows: int | None = None, @@ -224,6 +222,7 @@ def scan_ipc( row_index_offset: int = 0, storage_options: dict[str, Any] | None = None, memory_map: bool = True, + retries: int = 0, ) -> LazyFrame: """ Lazily read from an Arrow IPC (Feather v2) file or multiple files via glob patterns. @@ -254,6 +253,9 @@ def scan_ipc( Try to memory map the file. This can greatly improve performance on repeated queries as the OS may cache pages. Only uncompressed IPC files can be memory mapped. + retries + Number of retries if accessing a cloud instance fails. + """ return pl.LazyFrame._scan_ipc( source, @@ -264,4 +266,5 @@ def scan_ipc( row_index_offset=row_index_offset, storage_options=storage_options, memory_map=memory_map, + retries=retries, ) diff --git a/py-polars/polars/io/ndjson.py b/py-polars/polars/io/ndjson.py index d8e5aa9d403a6..9d413e6de10dc 100644 --- a/py-polars/polars/io/ndjson.py +++ b/py-polars/polars/io/ndjson.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING import polars._reexport as pl +from polars._utils.deprecation import deprecate_renamed_parameter from polars.datatypes import N_INFER_DEFAULT -from polars.utils.deprecation import deprecate_renamed_parameter if TYPE_CHECKING: from io import IOBase diff --git a/py-polars/polars/io/parquet/functions.py b/py-polars/polars/io/parquet/functions.py index a80da70569a31..8c0acd413a9f7 100644 --- a/py-polars/polars/io/parquet/functions.py +++ b/py-polars/polars/io/parquet/functions.py @@ -6,11 +6,11 @@ from typing import IO, TYPE_CHECKING, Any import polars._reexport as pl +from polars._utils.deprecation import deprecate_renamed_parameter +from polars._utils.various import is_int_sequence, normalize_filepath from polars.convert import from_arrow from polars.dependencies import _PYARROW_AVAILABLE from polars.io._utils import _prepare_file_arg -from polars.utils.deprecation import deprecate_renamed_parameter -from polars.utils.various import is_int_sequence, normalize_filepath with contextlib.suppress(ImportError): from polars.polars import read_parquet_schema as _read_parquet_schema diff --git a/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py b/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py index 2bae55a8be235..5ecaeacff2ca6 100644 --- a/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py +++ b/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py @@ -71,13 +71,13 @@ def _scan_pyarrow_dataset_impl( _filter = None if predicate: - from polars.datatypes import Date, Datetime, Duration - from polars.utils.convert import ( - _to_python_date, - _to_python_datetime, - _to_python_time, - _to_python_timedelta, + from polars._utils.convert import ( + to_py_date, + to_py_datetime, + to_py_time, + to_py_timedelta, ) + from polars.datatypes import Date, Datetime, Duration _filter = eval( predicate, @@ -86,10 +86,10 @@ def _scan_pyarrow_dataset_impl( "Date": Date, "Datetime": Datetime, "Duration": Duration, - "_to_python_date": _to_python_date, - "_to_python_datetime": _to_python_datetime, - "_to_python_time": _to_python_time, - "_to_python_timedelta": _to_python_timedelta, + "to_py_date": to_py_date, + "to_py_datetime": to_py_datetime, + "to_py_time": to_py_time, + "to_py_timedelta": to_py_timedelta, }, ) diff --git a/py-polars/polars/io/pyarrow_dataset/functions.py b/py-polars/polars/io/pyarrow_dataset/functions.py index f1d6edf8b1ebf..2d23a95bddfde 100644 --- a/py-polars/polars/io/pyarrow_dataset/functions.py +++ b/py-polars/polars/io/pyarrow_dataset/functions.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING +from polars._utils.unstable import unstable from polars.io.pyarrow_dataset.anonymous_scan import _scan_pyarrow_dataset -from polars.utils.unstable import unstable if TYPE_CHECKING: from polars import LazyFrame diff --git a/py-polars/polars/io/spreadsheet/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py index 43ce36c4f57a7..e13aff6188ba2 100644 --- a/py-polars/polars/io/spreadsheet/_write_utils.py +++ b/py-polars/polars/io/spreadsheet/_write_utils.py @@ -161,8 +161,7 @@ def _xl_column_range( *, include_header: bool, as_range: Literal[True] = ..., -) -> str: - ... +) -> str: ... @overload @@ -173,8 +172,7 @@ def _xl_column_range( *, include_header: bool, as_range: Literal[False], -) -> tuple[int, int, int, int]: - ... +) -> tuple[int, int, int, int]: ... def _xl_column_range( diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index fe15013b91223..4fde7f7a9ea9b 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -5,10 +5,12 @@ from datetime import time from io import BufferedReader, BytesIO, StringIO from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, NoReturn, Sequence, overload +from typing import IO, TYPE_CHECKING, Any, Callable, NoReturn, Sequence, overload import polars._reexport as pl from polars import functions as F +from polars._utils.deprecation import deprecate_renamed_parameter +from polars._utils.various import normalize_filepath from polars.datatypes import ( FLOAT_DTYPES, NUMERIC_DTYPES, @@ -22,8 +24,6 @@ from polars.exceptions import NoDataError, ParameterCollisionError from polars.io._utils import PortableTemporaryFile, _looks_like_url, _process_file_url from polars.io.csv.functions import read_csv -from polars.utils.deprecation import deprecate_renamed_parameter -from polars.utils.various import normalize_filepath if TYPE_CHECKING: from typing import Literal @@ -33,7 +33,7 @@ @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: str, @@ -42,13 +42,12 @@ def read_excel( read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... +) -> pl.DataFrame: ... @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: None = ..., @@ -57,13 +56,12 @@ def read_excel( read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... +) -> pl.DataFrame: ... @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: str, @@ -72,15 +70,14 @@ def read_excel( read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., -) -> NoReturn: - ... +) -> NoReturn: ... # note: 'ignore' required as mypy thinks that the return value for # Literal[0] overlaps with the return value for other integers @overload # type: ignore[overload-overlap] def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: Literal[0] | Sequence[int], sheet_name: None = ..., @@ -89,13 +86,12 @@ def read_excel( read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., -) -> dict[str, pl.DataFrame]: - ... +) -> dict[str, pl.DataFrame]: ... @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: None = ..., @@ -104,13 +100,12 @@ def read_excel( read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... +) -> pl.DataFrame: ... @overload def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None, sheet_name: list[str] | tuple[str], @@ -119,14 +114,13 @@ def read_excel( read_options: dict[str, Any] | None = ..., schema_overrides: SchemaDict | None = ..., raise_if_empty: bool = ..., -) -> dict[str, pl.DataFrame]: - ... +) -> dict[str, pl.DataFrame]: ... @deprecate_renamed_parameter("xlsx2csv_options", "engine_options", version="0.20.6") @deprecate_renamed_parameter("read_csv_options", "read_options", version="0.20.7") def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int | Sequence[int] | None = None, sheet_name: str | list[str] | tuple[str] | None = None, @@ -270,78 +264,72 @@ def read_excel( @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: str, schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... +) -> pl.DataFrame: ... @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None = ..., sheet_name: None = ..., schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... +) -> pl.DataFrame: ... @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: str, schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., -) -> NoReturn: - ... +) -> NoReturn: ... @overload # type: ignore[overload-overlap] def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: Literal[0] | Sequence[int], sheet_name: None = ..., schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., -) -> dict[str, pl.DataFrame]: - ... +) -> dict[str, pl.DataFrame]: ... @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int, sheet_name: None = ..., schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... +) -> pl.DataFrame: ... @overload def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: None, sheet_name: list[str] | tuple[str], schema_overrides: SchemaDict | None = None, raise_if_empty: bool = ..., -) -> dict[str, pl.DataFrame]: - ... +) -> dict[str, pl.DataFrame]: ... def read_ods( - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, *, sheet_id: int | Sequence[int] | None = None, sheet_name: str | list[str] | tuple[str] | None = None, @@ -406,7 +394,7 @@ def read_ods( ) -def _identify_from_magic_bytes(data: bytes | BinaryIO | BytesIO) -> str | None: +def _identify_from_magic_bytes(data: IO[bytes] | bytes) -> str | None: if isinstance(data, bytes): data = BytesIO(data) @@ -425,7 +413,7 @@ def _identify_from_magic_bytes(data: bytes | BinaryIO | BytesIO) -> str | None: data.seek(initial_position) -def _identify_workbook(wb: str | bytes | Path | BinaryIO | BytesIO) -> str | None: +def _identify_workbook(wb: str | Path | IO[bytes] | bytes) -> str | None: """Use file extension (and magic bytes) to identify Workbook type.""" if not isinstance(wb, (str, Path)): # raw binary data (bytesio, etc) @@ -449,7 +437,7 @@ def _identify_workbook(wb: str | bytes | Path | BinaryIO | BytesIO) -> str | Non def _read_spreadsheet( sheet_id: int | Sequence[int] | None, sheet_name: str | list[str] | tuple[str] | None, - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, engine: ExcelSpreadsheetEngine | Literal["ods"] | None, engine_options: dict[str, Any] | None = None, read_options: dict[str, Any] | None = None, @@ -557,7 +545,7 @@ def _get_sheet_names( def _initialise_spreadsheet_parser( engine: str | None, - source: str | BytesIO | Path | BinaryIO | bytes, + source: str | Path | IO[bytes] | bytes, engine_options: dict[str, Any], ) -> tuple[Callable[..., pl.DataFrame], Any, list[dict[str, Any]]]: """Instantiate the indicated spreadsheet parser and establish related properties.""" diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 1f228202e2968..8947e9618f20f 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -24,6 +24,31 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.async_ import _AioDataFrameResult, _GeventDataFrameResult +from polars._utils.convert import negate_duration_string, parse_as_duration_string +from polars._utils.deprecation import ( + deprecate_function, + deprecate_parameter_as_positional, + deprecate_renamed_function, + deprecate_renamed_parameter, + deprecate_saturating, + issue_deprecation_warning, +) +from polars._utils.parse_expr_input import ( + parse_as_expression, + parse_as_list_of_expressions, +) +from polars._utils.unstable import issue_unstable_warning, unstable +from polars._utils.various import ( + _in_notebook, + _prepare_row_index_args, + _process_null_values, + is_bool_sequence, + is_sequence, + normalize_filepath, + parse_percentiles, +) +from polars._utils.wrap import wrap_df, wrap_expr from polars.convert import from_dict from polars.datatypes import ( DTYPE_TEMPORAL_UNITS, @@ -62,31 +87,6 @@ from polars.lazyframe.in_process import InProcessQuery from polars.selectors import _expand_selectors, by_dtype, expand_selector from polars.slice import LazyPolarsSlice -from polars.utils._async import _AioDataFrameResult, _GeventDataFrameResult -from polars.utils._parse_expr_input import ( - parse_as_expression, - parse_as_list_of_expressions, -) -from polars.utils._wrap import wrap_df, wrap_expr -from polars.utils.convert import _negate_duration, _timedelta_to_pl_duration -from polars.utils.deprecation import ( - deprecate_function, - deprecate_parameter_as_positional, - deprecate_renamed_function, - deprecate_renamed_parameter, - deprecate_saturating, - issue_deprecation_warning, -) -from polars.utils.unstable import issue_unstable_warning, unstable -from polars.utils.various import ( - _in_notebook, - _prepare_row_index_args, - _process_null_values, - is_bool_sequence, - is_sequence, - normalize_filepath, - parse_percentiles, -) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyLazyFrame @@ -166,15 +166,19 @@ class LazyFrame: The number of entries in the schema should match the underlying data dimensions, unless a sequence of dictionaries is being passed, in which case a *partial* schema can be declared to prevent specific fields from being loaded. + strict : bool, default True + Throw an error if any `data` value does not exactly match the given or inferred + data type for that column. If set to `False`, values that do not match the data + type are cast to that data type or, if casting is not possible, set to null + instead. orient : {'col', 'row'}, default None Whether to interpret two-dimensional data as columns or as rows. If None, the orientation is inferred by matching the columns and data dimensions. If this does not yield conclusive results, column orientation is used. infer_schema_length : int or None - The maximum number of rows to scan for schema inference. - If set to `None`, the full data may be scanned *(this is slow)*. - This parameter only applies if the input data is a sequence or generator of - rows; other input is read as-is. + The maximum number of rows to scan for schema inference. If set to `None`, the + full data may be scanned *(this can be slow)*. This parameter only applies if + the input data is a sequence or generator of rows; other input is read as-is. nan_to_null : bool, default False If the data comes from one or more numpy arrays, can optionally convert input data np.nan values to null instead. This is a no-op for all other input data. @@ -295,6 +299,7 @@ def __init__( schema: SchemaDefinition | None = None, *, schema_overrides: SchemaDict | None = None, + strict: bool = True, orient: Orientation | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, nan_to_null: bool = False, @@ -306,6 +311,7 @@ def __init__( data=data, schema=schema, schema_overrides=schema_overrides, + strict=strict, orient=orient, infer_schema_length=infer_schema_length, nan_to_null=nan_to_null, @@ -491,6 +497,7 @@ def _scan_ipc( row_index_offset: int = 0, storage_options: dict[str, object] | None = None, memory_map: bool = True, + retries: int = 0, ) -> Self: """ Lazily read from an Arrow IPC (Feather v2) file. @@ -528,6 +535,8 @@ def _scan_ipc( rechunk, _prepare_row_index_args(row_index_name, row_index_offset), memory_map=memory_map, + cloud_options=storage_options, + retries=retries, ) return self @@ -799,12 +808,10 @@ def _repr_html_(self) -> str: """ @overload - def serialize(self, file: None = ...) -> str: - ... + def serialize(self, file: None = ...) -> str: ... @overload - def serialize(self, file: IOBase | str | Path) -> None: - ... + def serialize(self, file: IOBase | str | Path) -> None: ... def serialize(self, file: IOBase | str | Path | None = None) -> str | None: """ @@ -1001,10 +1008,11 @@ def describe( Customize which percentiles are displayed, applying linear interpolation: - >>> lf.describe( - ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], - ... interpolation="linear", - ... ) + >>> with pl.Config(tbl_rows=12): + ... lf.describe( + ... percentiles=[0.1, 0.3, 0.5, 0.7, 0.9], + ... interpolation="linear", + ... ) shape: (11, 7) ┌────────────┬──────────┬──────────┬──────────┬──────┬────────────┬──────────┐ │ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │ @@ -1784,8 +1792,7 @@ def collect( streaming: bool = False, background: Literal[True], _eager: bool = False, - ) -> InProcessQuery: - ... + ) -> InProcessQuery: ... @overload def collect( @@ -1802,8 +1809,7 @@ def collect( streaming: bool = False, background: Literal[False] = False, _eager: bool = False, - ) -> DataFrame: - ... + ) -> DataFrame: ... def collect( self, @@ -1950,8 +1956,7 @@ def collect_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = True, - ) -> _GeventDataFrameResult[DataFrame]: - ... + ) -> _GeventDataFrameResult[DataFrame]: ... @overload def collect_async( @@ -1967,8 +1972,7 @@ def collect_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = True, - ) -> Awaitable[DataFrame]: - ... + ) -> Awaitable[DataFrame]: ... def collect_async( self, @@ -3233,7 +3237,7 @@ def rolling( check_sorted: bool = True, ) -> LazyGroupBy: """ - Create rolling groups based on a time, Int32, or Int64 column. + Create rolling groups based on a temporal or integer column. Different from a `dynamic_group_by` the windows are now determined by the individual values and are not of constant intervals. For constant intervals @@ -3277,11 +3281,6 @@ def rolling( not be 24 hours, due to daylight savings). Similarly for "calendar week", "calendar month", "calendar quarter", and "calendar year". - In case of a rolling operation on an integer column, the windows are defined by: - - - "1i" # length 1 - - "10i" # length 10 - Parameters ---------- index_column @@ -3291,8 +3290,8 @@ def rolling( then it must be sorted in ascending order within each group). In case of a rolling group by on indices, dtype needs to be one of - {Int32, Int64}. Note that Int32 gets temporarily cast to Int64, so if - performance matters use an Int64 column. + {UInt32, UInt64, Int32, Int64}. Note that the first three get temporarily + cast to Int64, so if performance matters use an Int64 column. period length of the window - must be non-negative offset @@ -3360,11 +3359,11 @@ def rolling( offset = deprecate_saturating(offset) index_column = parse_as_expression(index_column) if offset is None: - offset = _negate_duration(_timedelta_to_pl_duration(period)) + offset = negate_duration_string(parse_as_duration_string(period)) pyexprs_by = parse_as_list_of_expressions(by) if by is not None else [] - period = _timedelta_to_pl_duration(period) - offset = _timedelta_to_pl_duration(offset) + period = parse_as_duration_string(period) + offset = parse_as_duration_string(offset) lgb = self._ldf.rolling( index_column, period, offset, closed, pyexprs_by, check_sorted @@ -3706,14 +3705,14 @@ def group_by_dynamic( index_column = parse_as_expression(index_column) if offset is None: - offset = _negate_duration(_timedelta_to_pl_duration(every)) + offset = negate_duration_string(parse_as_duration_string(every)) if period is None: period = every - period = _timedelta_to_pl_duration(period) - offset = _timedelta_to_pl_duration(offset) - every = _timedelta_to_pl_duration(every) + period = parse_as_duration_string(period) + offset = parse_as_duration_string(offset) + every = parse_as_duration_string(every) pyexprs_by = parse_as_list_of_expressions(by) if by is not None else [] lgb = self._ldf.group_by_dynamic( @@ -3890,7 +3889,7 @@ def join_asof( if isinstance(tolerance, str): tolerance_str = tolerance elif isinstance(tolerance, timedelta): - tolerance_str = _timedelta_to_pl_duration(tolerance) + tolerance_str = parse_as_duration_string(tolerance) else: tolerance_num = tolerance @@ -3952,7 +3951,7 @@ def join( * *outer_coalesce* Same as 'outer', but coalesces the key columns * *cross* - Returns the cartisian product of rows from both tables + Returns the Cartesian product of rows from both tables * *semi* Filter rows that have a match in the right table. * *anti* @@ -4813,10 +4812,16 @@ def first(self) -> Self: """ return self.slice(0, 1) + @deprecate_function( + "Use `select(pl.all().approx_n_unique())` instead.", version="0.20.11" + ) def approx_n_unique(self) -> Self: """ Approximate count of unique values. + .. deprecated:: 0.20.11 + Use `select(pl.all().approx_n_unique())` instead. + This is done using the HyperLogLog++ algorithm for cardinality estimation. Examples @@ -4827,7 +4832,7 @@ def approx_n_unique(self) -> Self: ... "b": [1, 2, 1, 1], ... } ... ) - >>> lf.approx_n_unique().collect() + >>> lf.approx_n_unique().collect() # doctest: +SKIP shape: (1, 2) ┌─────┬─────┐ │ a ┆ b │ diff --git a/py-polars/polars/lazyframe/group_by.py b/py-polars/polars/lazyframe/group_by.py index b8e3aa588c7cc..97f118fa06a58 100644 --- a/py-polars/polars/lazyframe/group_by.py +++ b/py-polars/polars/lazyframe/group_by.py @@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Callable, Iterable from polars import functions as F -from polars.utils._parse_expr_input import parse_as_list_of_expressions -from polars.utils._wrap import wrap_ldf -from polars.utils.deprecation import deprecate_renamed_function +from polars._utils.deprecation import deprecate_renamed_function +from polars._utils.parse_expr_input import parse_as_list_of_expressions +from polars._utils.wrap import wrap_ldf if TYPE_CHECKING: from polars import DataFrame, LazyFrame @@ -347,16 +347,16 @@ def len(self) -> LazyFrame: ... "b": [1, None, 2], ... } ... ) - >>> lf.group_by("a").count().collect() # doctest: +SKIP + >>> lf.group_by("a").len().collect() # doctest: +SKIP shape: (2, 2) - ┌────────┬───────┐ - │ a ┆ count │ - │ --- ┆ --- │ - │ str ┆ u32 │ - ╞════════╪═══════╡ - │ apple ┆ 2 │ - │ orange ┆ 1 │ - └────────┴───────┘ + ┌────────┬─────┐ + │ a ┆ len │ + │ --- ┆ --- │ + │ str ┆ u32 │ + ╞════════╪═════╡ + │ apple ┆ 2 │ + │ orange ┆ 1 │ + └────────┴─────┘ """ return self.agg(F.len()) @@ -365,6 +365,9 @@ def count(self) -> LazyFrame: """ Return the number of rows in each group. + .. deprecated:: 0.20.5 + This method has been renamed to :func:`LazyGroupBy.len`. + Rows containing null values count towards the total. Examples diff --git a/py-polars/polars/lazyframe/in_process.py b/py-polars/polars/lazyframe/in_process.py index 3c04020a2f688..9bcee8dccfa07 100644 --- a/py-polars/polars/lazyframe/in_process.py +++ b/py-polars/polars/lazyframe/in_process.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from polars.utils._wrap import wrap_df +from polars._utils.wrap import wrap_df if TYPE_CHECKING: from polars import DataFrame diff --git a/py-polars/polars/meta/__init__.py b/py-polars/polars/meta/__init__.py index b9e84653ebc84..8328c8db87f29 100644 --- a/py-polars/polars/meta/__init__.py +++ b/py-polars/polars/meta/__init__.py @@ -1,4 +1,5 @@ """Public functions that provide information about the Polars package or the environment it runs in.""" # noqa: W505 + from polars.meta.build import build_info from polars.meta.index_type import get_index_type from polars.meta.thread_pool import thread_pool_size, threadpool_size diff --git a/py-polars/polars/meta/build.py b/py-polars/polars/meta/build.py index d38d92fc4414f..4b6abcf5fcf91 100644 --- a/py-polars/polars/meta/build.py +++ b/py-polars/polars/meta/build.py @@ -2,7 +2,7 @@ from typing import Any -from polars.utils._polars_version import get_polars_version +from polars._utils.polars_version import get_polars_version try: from polars.polars import __build__ diff --git a/py-polars/polars/meta/thread_pool.py b/py-polars/polars/meta/thread_pool.py index 446eb486ceb21..c86f358e7e6eb 100644 --- a/py-polars/polars/meta/thread_pool.py +++ b/py-polars/polars/meta/thread_pool.py @@ -2,7 +2,7 @@ import contextlib -from polars.utils.deprecation import deprecate_renamed_function +from polars._utils.deprecation import deprecate_renamed_function with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/meta/versions.py b/py-polars/polars/meta/versions.py index 305b2e2a17f8a..6069aaecf3383 100644 --- a/py-polars/polars/meta/versions.py +++ b/py-polars/polars/meta/versions.py @@ -2,8 +2,8 @@ import sys +from polars._utils.polars_version import get_polars_version from polars.meta.index_type import get_index_type -from polars.utils._polars_version import get_polars_version def show_versions() -> None: @@ -14,29 +14,30 @@ def show_versions() -> None: -------- >>> pl.show_versions() # doctest: +SKIP --------Version info--------- - Polars: 0.19.16 + Polars: 0.20.14 Index type: UInt32 - Platform: macOS-14.1.1-arm64-arm-64bit - Python: 3.11.6 (main, Oct 2 2023, 13:45:54) [Clang 15.0.0 (clang-1500.0.40.1)] + Platform: macOS-14.3.1-arm64-arm-64bit + Python: 3.11.8 (main, Feb 6 2024, 21:21:21) [Clang 15.0.0 (clang-1500.1.0.2.5)] ----Optional dependencies---- - adbc_driver_manager: 0.8.0 + adbc_driver_manager: 0.10.0 cloudpickle: 3.0.0 connectorx: 0.3.2 - deltalake: 0.13.0 - fsspec: 2023.10.0 - hvplot: 0.9.1 - gevent: 23.9.1 - matplotlib: 3.8.2 - numpy: 1.26.2 + deltalake: 0.16.0 + fastexcel: 0.9.1 + fsspec: 2023.12.2 + gevent: 24.2.1 + hvplot: 0.9.2 + matplotlib: 3.8.3 + numpy: 1.26.4 openpyxl: 3.1.2 - pandas: 2.1.3 - pyarrow: 14.0.1 - pydantic: 2.5.2 - pyiceberg: 0.5.1 + pandas: 2.2.1 + pyarrow: 15.0.0 + pydantic: 2.6.3 + pyiceberg: 0.6.0 pyxlsb: 1.0.10 - sqlalchemy: 2.0.23 - xlsx2csv: 0.8.1 - xlsxwriter: 3.1.9 + sqlalchemy: 2.0.28 + xlsx2csv: 0.8.2 + xlsxwriter: 3.2.0 """ # noqa: W505 # Note: we import 'platform' here (rather than at the top of the # module) as a micro-optimization for polars' initial import @@ -64,6 +65,7 @@ def _get_dependency_info() -> dict[str, str]: "cloudpickle", "connectorx", "deltalake", + "fastexcel", "fsspec", "gevent", "hvplot", diff --git a/py-polars/polars/plugins.py b/py-polars/polars/plugins.py new file mode 100644 index 0000000000000..44a5f1a3b7a54 --- /dev/null +++ b/py-polars/polars/plugins.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import contextlib +from pathlib import Path +from typing import TYPE_CHECKING, Any, Iterable + +from polars._utils.parse_expr_input import parse_as_list_of_expressions +from polars._utils.wrap import wrap_expr + +with contextlib.suppress(ImportError): # Module not available when building docs + import polars.polars as plr + +if TYPE_CHECKING: + from polars import Expr + from polars.type_aliases import IntoExpr + +__all__ = ["register_plugin_function"] + + +def register_plugin_function( + *, + plugin_path: Path | str, + function_name: str, + args: IntoExpr | Iterable[IntoExpr], + kwargs: dict[str, Any] | None = None, + is_elementwise: bool = False, + changes_length: bool = False, + returns_scalar: bool = False, + cast_to_supertype: bool = False, + input_wildcard_expansion: bool = False, + pass_name_to_apply: bool = False, +) -> Expr: + """ + Register a plugin function. + + See the `user guide `_ + for more information about plugins. + + Parameters + ---------- + plugin_path + Path to the plugin package. Accepts either the file path to the dynamic library + file or the path to the directory containing it. + function_name + The name of the Rust function to register. + args + The arguments passed to this function. These get passed to the `input` + argument on the Rust side, and have to be expressions (or be convertible + to expressions). + kwargs + Non-expression arguments to the plugin function. These must be + JSON serializable. + is_elementwise + Indicate that the function operates on scalars only. This will potentially + trigger fast paths. + changes_length + Indicate that the function will change the length of the expression. + For example, a `unique` or `slice` operation. + returns_scalar + Automatically explode on unit length if the function ran as final aggregation. + This is the case for aggregations like `sum`, `min`, `covariance` etc. + cast_to_supertype + Cast the input expressions to their supertype. + input_wildcard_expansion + Expand wildcard expressions before executing the function. + pass_name_to_apply + If set to `True`, the `Series` passed to the function in a group-by operation + will ensure the name is set. This is an extra heap allocation per group. + + Returns + ------- + Expr + + Warnings + -------- + This is highly unsafe as this will call the C function loaded by + `plugin::function_name`. + + The parameters you set dictate how Polars will handle the function. + Make sure they are correct! + """ + pyexprs = parse_as_list_of_expressions(args) + serialized_kwargs = _serialize_kwargs(kwargs) + plugin_path = _resolve_plugin_path(plugin_path) + + return wrap_expr( + plr.register_plugin_function( + plugin_path=str(plugin_path), + function_name=function_name, + args=pyexprs, + kwargs=serialized_kwargs, + is_elementwise=is_elementwise, + input_wildcard_expansion=input_wildcard_expansion, + returns_scalar=returns_scalar, + cast_to_supertype=cast_to_supertype, + pass_name_to_apply=pass_name_to_apply, + changes_length=changes_length, + ) + ) + + +def _serialize_kwargs(kwargs: dict[str, Any] | None) -> bytes: + """Serialize the function's keyword arguments.""" + if not kwargs: + return b"" + + import pickle + + # Use the highest pickle protocol supported the serde-pickle crate: + # https://docs.rs/serde-pickle/latest/serde_pickle/ + return pickle.dumps(kwargs, protocol=5) + + +def _resolve_plugin_path(path: Path | str) -> Path: + """Get the file path of the dynamic library file.""" + if not isinstance(path, Path): + path = Path(path) + + if path.is_file(): + return path.resolve() + + for p in path.iterdir(): + if _is_dynamic_lib(p): + return p.resolve() + else: + msg = f"no dynamic library found at path: {path}" + raise FileNotFoundError(msg) + + +def _is_dynamic_lib(path: Path) -> bool: + return path.is_file() and path.suffix in (".so", ".dll", ".pyd") diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 0793f5b9ee570..b732c53bab42a 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -7,6 +7,9 @@ from typing import TYPE_CHECKING, Any, Collection, Literal, Mapping, overload from polars import functions as F +from polars._utils.deprecation import deprecate_nonkeyword_arguments +from polars._utils.parse_expr_input import _parse_inputs_as_iterable +from polars._utils.various import is_column from polars.datatypes import ( FLOAT_DTYPES, INTEGER_DTYPES, @@ -27,9 +30,6 @@ is_polars_dtype, ) from polars.expr import Expr -from polars.utils._parse_expr_input import _parse_inputs_as_iterable -from polars.utils.deprecation import deprecate_nonkeyword_arguments -from polars.utils.various import is_column if TYPE_CHECKING: import sys @@ -50,8 +50,7 @@ def is_selector(obj: _selector_proxy_) -> Literal[True]: # type: ignore[overloa @overload -def is_selector(obj: Any) -> Literal[False]: - ... +def is_selector(obj: Any) -> Literal[False]: ... def is_selector(obj: Any) -> bool: diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 4a547485f962a..04c88f701575a 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Callable, Sequence from polars import functions as F +from polars._utils.wrap import wrap_s from polars.series.utils import expr_dispatch -from polars.utils._wrap import wrap_s if TYPE_CHECKING: from datetime import date, datetime, time diff --git a/py-polars/polars/series/binary.py b/py-polars/polars/series/binary.py index 482773bef65f1..2796ecb403eb5 100644 --- a/py-polars/polars/series/binary.py +++ b/py-polars/polars/series/binary.py @@ -92,7 +92,7 @@ def starts_with(self, prefix: IntoExpr) -> Series: def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: r""" - Decode a value using the provided encoding. + Decode values using the provided encoding. Parameters ---------- @@ -102,8 +102,15 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + Returns + ------- + Series + Series of data type :class:`String`. + Examples -------- + Decode values using hexadecimal encoding. + >>> s = pl.Series("colors", [b"000000", b"ffff00", b"0000ff"]) >>> s.bin.decode("hex") shape: (3,) @@ -113,6 +120,9 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: b"\xff\xff\x00" b"\x00\x00\xff" ] + + Decode values using Base64 encoding. + >>> s = pl.Series("colors", [b"AAAA", b"//8A", b"AAD/"]) >>> s.bin.decode("base64") shape: (3,) @@ -122,11 +132,23 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: b"\xff\xff\x00" b"\x00\x00\xff" ] + + Set `strict=False` to set invalid values to null instead of raising an error. + + >>> s = pl.Series("colors", [b"000000", b"ffff00", b"invalid_value"]) + >>> s.bin.decode("hex", strict=False) + shape: (3,) + Series: 'colors' [binary] + [ + b"\x00\x00\x00" + b"\xff\xff\x00" + null + ] """ def encode(self, encoding: TransferEncoding) -> Series: r""" - Encode a value using the provided encoding. + Encode values using the provided encoding. Parameters ---------- @@ -136,10 +158,12 @@ def encode(self, encoding: TransferEncoding) -> Series: Returns ------- Series - Series of data type :class:`Boolean`. + Series of data type :class:`String`. Examples -------- + Encode values using hexadecimal encoding. + >>> s = pl.Series("colors", [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"]) >>> s.bin.encode("hex") shape: (3,) @@ -149,6 +173,9 @@ def encode(self, encoding: TransferEncoding) -> Series: "ffff00" "0000ff" ] + + Encode values using Base64 encoding. + >>> s.bin.encode("base64") shape: (3,) Series: 'colors' [str] diff --git a/py-polars/polars/series/categorical.py b/py-polars/polars/series/categorical.py index 03057ea81f98c..204a125853a97 100644 --- a/py-polars/polars/series/categorical.py +++ b/py-polars/polars/series/categorical.py @@ -2,10 +2,10 @@ from typing import TYPE_CHECKING +from polars._utils.deprecation import deprecate_function +from polars._utils.unstable import unstable +from polars._utils.wrap import wrap_s from polars.series.utils import expr_dispatch -from polars.utils._wrap import wrap_s -from polars.utils.deprecation import deprecate_function -from polars.utils.unstable import unstable if TYPE_CHECKING: from polars import Series diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 0cc026133c3ea..1add0fba24e8c 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING +from polars._utils.convert import to_py_date, to_py_datetime +from polars._utils.deprecation import deprecate_function, deprecate_renamed_function +from polars._utils.unstable import unstable +from polars._utils.wrap import wrap_s from polars.datatypes import Date, Datetime, Duration from polars.series.utils import expr_dispatch -from polars.utils._wrap import wrap_s -from polars.utils.convert import _to_python_date, _to_python_datetime -from polars.utils.deprecation import deprecate_function, deprecate_renamed_function -from polars.utils.unstable import unstable if TYPE_CHECKING: import datetime as dt @@ -81,11 +81,11 @@ def median(self) -> TemporalLiteral | float | None: out = s.median() if out is not None: if s.dtype == Date: - return _to_python_date(int(out)) # type: ignore[arg-type] + return to_py_date(int(out)) # type: ignore[arg-type] elif s.dtype in (Datetime, Duration): return out # type: ignore[return-value] else: - return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined] + return to_py_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined] return None def mean(self) -> TemporalLiteral | float | None: @@ -105,11 +105,11 @@ def mean(self) -> TemporalLiteral | float | None: out = s.mean() if out is not None: if s.dtype == Date: - return _to_python_date(int(out)) # type: ignore[arg-type] + return to_py_date(int(out)) # type: ignore[arg-type] elif s.dtype in (Datetime, Duration): return out # type: ignore[return-value] else: - return _to_python_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined] + return to_py_datetime(int(out), s.dtype.time_unit) # type: ignore[arg-type, attr-defined] return None def to_string(self, format: str) -> Series: @@ -1157,6 +1157,7 @@ def replace_time_zone( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime + - `'null'`: set to null Examples -------- @@ -1195,7 +1196,7 @@ def replace_time_zone( │ 2020-07-01 01:00:00 BST ┆ 2020-07-01 01:00:00 CEST │ └─────────────────────────────┴────────────────────────────────┘ - You can use `use_earliest` to deal with ambiguous datetimes: + You can use `ambiguous` to deal with ambiguous datetimes: >>> dates = [ ... "2018-10-28 01:30", diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index e7c5eb2f08287..ca89c0ceea2d5 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -3,12 +3,12 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence from polars import functions as F -from polars.series.utils import expr_dispatch -from polars.utils._wrap import wrap_s -from polars.utils.deprecation import ( +from polars._utils.deprecation import ( deprecate_renamed_function, deprecate_renamed_parameter, ) +from polars._utils.wrap import wrap_s +from polars.series.utils import expr_dispatch if TYPE_CHECKING: from datetime import date, datetime, time diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index a79686b057c5d..7a0328bcba2ce 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -22,6 +22,41 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.construction import ( + arrow_to_pyseries, + dataframe_to_pyseries, + iterable_to_pyseries, + numpy_to_idxs, + numpy_to_pyseries, + pandas_to_pyseries, + sequence_to_pyseries, + series_to_pyseries, +) +from polars._utils.convert import ( + date_to_int, + datetime_to_int, + time_to_int, + timedelta_to_int, +) +from polars._utils.deprecation import ( + deprecate_function, + deprecate_nonkeyword_arguments, + deprecate_renamed_function, + deprecate_renamed_parameter, + issue_deprecation_warning, +) +from polars._utils.unstable import unstable +from polars._utils.various import ( + BUILDING_SPHINX_DOCS, + _is_generator, + no_default, + parse_version, + range_to_slice, + scale_bytes, + sphinx_accessor, + warn_null_comparison, +) +from polars._utils.wrap import wrap_df from polars.datatypes import ( Array, Boolean, @@ -74,41 +109,6 @@ from polars.series.struct import StructNameSpace from polars.series.utils import expr_dispatch, get_ffi_func from polars.slice import PolarsSlice -from polars.utils._construction import ( - arrow_to_pyseries, - dataframe_to_pyseries, - iterable_to_pyseries, - numpy_to_idxs, - numpy_to_pyseries, - pandas_to_pyseries, - sequence_to_pyseries, - series_to_pyseries, -) -from polars.utils._wrap import wrap_df -from polars.utils.convert import ( - _date_to_pl_date, - _datetime_to_pl_timestamp, - _time_to_pl_time, - _timedelta_to_pl_timedelta, -) -from polars.utils.deprecation import ( - deprecate_function, - deprecate_nonkeyword_arguments, - deprecate_renamed_function, - deprecate_renamed_parameter, - issue_deprecation_warning, -) -from polars.utils.unstable import unstable -from polars.utils.various import ( - BUILDING_SPHINX_DOCS, - _is_generator, - no_default, - parse_version, - range_to_slice, - scale_bytes, - sphinx_accessor, - warn_null_comparison, -) with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PySeries @@ -119,6 +119,9 @@ from hvplot.plotting.core import hvPlotTabularPolars from polars import DataFrame, DataType, Expr + from polars._utils.various import ( + NoDefault, + ) from polars.series._numpy import SeriesView from polars.type_aliases import ( BufferInfo, @@ -140,9 +143,6 @@ SizeUnit, TemporalLiteral, ) - from polars.utils.various import ( - NoDefault, - ) if sys.version_info >= (3, 11): from typing import Self @@ -176,10 +176,22 @@ class Series: One-dimensional data in various forms. Supported are: Sequence, Series, pyarrow Array, and numpy ndarray. dtype : DataType, default None - Polars dtype of the Series data. If not specified, the dtype is inferred. - strict - Throw error on numeric overflow. - nan_to_null + Data type of the resulting Series. If set to `None` (default), the data type is + inferred from the `values` input. The strategy for data type inference depends + on the `strict` parameter: + + - If `strict` is set to True (default), the inferred data type is equal to the + first non-null value, or `Null` if all values are null. + - If `strict` is set to False, the inferred data type is the supertype of the + values, or :class:`Object` if no supertype can be found. **WARNING**: A full + pass over the values is required to determine the supertype. + - If no values were passed, the resulting data type is :class:`Null`. + + strict : bool, default True + Throw an error if any value does not exactly match the given or inferred data + type. If set to `False`, values that do not match the data type are cast to + that data type or, if casting is not possible, set to null instead. + nan_to_null : bool, default False In case a numpy array is used to create this Series, indicate how to deal with np.nan values. (This parameter is a no-op on non-numpy data). dtype_if_empty : DataType, default Null @@ -268,17 +280,18 @@ def __init__( version="0.20.6", ) - # If 'Unknown' treat as None to attempt inference + # If 'Unknown' treat as None to trigger type inference if dtype == Unknown: dtype = None - # Raise early error on invalid dtype - elif ( - dtype is not None - and not is_polars_dtype(dtype) - and py_type_to_dtype(dtype, raise_unmatched=False) is None - ): - msg = f"given dtype: {dtype!r} is not a valid Polars data type and cannot be converted into one" - raise ValueError(msg) + elif dtype is not None and not is_polars_dtype(dtype): + # Raise early error on invalid dtype + if not is_polars_dtype( + pl_dtype := py_type_to_dtype(dtype, raise_unmatched=False) + ): + msg = f"given dtype: {dtype!r} is not a valid Polars data type and cannot be converted into one" + raise ValueError(msg) + else: + dtype = pl_dtype # Handle case where values are passed as the first argument original_name: str | None = None @@ -694,20 +707,20 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: else: msg = f"cannot compare datetime.datetime to Series of type {self.dtype}" raise ValueError(msg) - ts = _datetime_to_pl_timestamp(other, time_unit) # type: ignore[arg-type] + ts = datetime_to_int(other, time_unit) # type: ignore[arg-type] f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(ts)) elif isinstance(other, time) and self.dtype == Time: - d = _time_to_pl_time(other) + d = time_to_int(other) f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(d)) elif isinstance(other, timedelta) and self.dtype == Duration: time_unit = self.dtype.time_unit # type: ignore[attr-defined] - td = _timedelta_to_pl_timedelta(other, time_unit) # type: ignore[arg-type] + td = timedelta_to_int(other, time_unit) # type: ignore[arg-type] f = get_ffi_func(op + "_<>", Int64, self._s) assert f is not None return self._from_pyseries(f(td)) @@ -716,7 +729,7 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: other = Series([other]) elif isinstance(other, date) and self.dtype == Date: - d = _date_to_pl_date(other) + d = date_to_int(other) f = get_ffi_func(op + "_<>", Int32, self._s) assert f is not None return self._from_pyseries(f(d)) @@ -744,8 +757,7 @@ def __eq__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __eq__(self, other: Any) -> Series: - ... + def __eq__(self, other: Any) -> Series: ... def __eq__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -758,8 +770,7 @@ def __ne__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __ne__(self, other: Any) -> Series: - ... + def __ne__(self, other: Any) -> Series: ... def __ne__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -772,8 +783,7 @@ def __gt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __gt__(self, other: Any) -> Series: - ... + def __gt__(self, other: Any) -> Series: ... def __gt__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -786,8 +796,7 @@ def __lt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __lt__(self, other: Any) -> Series: - ... + def __lt__(self, other: Any) -> Series: ... def __lt__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -800,8 +809,7 @@ def __ge__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __ge__(self, other: Any) -> Series: - ... + def __ge__(self, other: Any) -> Series: ... def __ge__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -814,8 +822,7 @@ def __le__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __le__(self, other: Any) -> Series: - ... + def __le__(self, other: Any) -> Series: ... def __le__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -828,8 +835,7 @@ def le(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def le(self, other: Any) -> Series: - ... + def le(self, other: Any) -> Series: ... def le(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series <= other`.""" @@ -840,8 +846,7 @@ def lt(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def lt(self, other: Any) -> Series: - ... + def lt(self, other: Any) -> Series: ... def lt(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series < other`.""" @@ -852,8 +857,7 @@ def eq(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def eq(self, other: Any) -> Series: - ... + def eq(self, other: Any) -> Series: ... def eq(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series == other`.""" @@ -864,8 +868,7 @@ def eq_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def eq_missing(self, other: Any) -> Series: - ... + def eq_missing(self, other: Any) -> Series: ... def eq_missing(self, other: Any) -> Series | Expr: """ @@ -913,8 +916,7 @@ def ne(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ne(self, other: Any) -> Series: - ... + def ne(self, other: Any) -> Series: ... def ne(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series != other`.""" @@ -925,8 +927,7 @@ def ne_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ne_missing(self, other: Any) -> Series: - ... + def ne_missing(self, other: Any) -> Series: ... def ne_missing(self, other: Any) -> Series | Expr: """ @@ -974,8 +975,7 @@ def ge(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ge(self, other: Any) -> Series: - ... + def ge(self, other: Any) -> Series: ... def ge(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series >= other`.""" @@ -986,8 +986,7 @@ def gt(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def gt(self, other: Any) -> Series: - ... + def gt(self, other: Any) -> Series: ... def gt(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series > other`.""" @@ -1041,8 +1040,7 @@ def __add__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __add__(self, other: Any) -> Self: - ... + def __add__(self, other: Any) -> Self: ... def __add__(self, other: Any) -> Self | DataFrame | Expr: if isinstance(other, str): @@ -1058,8 +1056,7 @@ def __sub__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __sub__(self, other: Any) -> Self: - ... + def __sub__(self, other: Any) -> Self: ... def __sub__(self, other: Any) -> Self | Expr: if isinstance(other, pl.Expr): @@ -1071,8 +1068,7 @@ def __truediv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __truediv__(self, other: Any) -> Series: - ... + def __truediv__(self, other: Any) -> Series: ... def __truediv__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1092,8 +1088,7 @@ def __floordiv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __floordiv__(self, other: Any) -> Series: - ... + def __floordiv__(self, other: Any) -> Series: ... def __floordiv__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1118,8 +1113,7 @@ def __mul__(self, other: DataFrame) -> DataFrame: # type: ignore[overload-overl ... @overload - def __mul__(self, other: Any) -> Series: - ... + def __mul__(self, other: Any) -> Series: ... def __mul__(self, other: Any) -> Series | DataFrame | Expr: if isinstance(other, pl.Expr): @@ -1137,8 +1131,7 @@ def __mod__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __mod__(self, other: Any) -> Series: - ... + def __mod__(self, other: Any) -> Series: ... def __mod__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1302,15 +1295,13 @@ def _take_with_series(self, s: Series) -> Series: return self._from_pyseries(self._s.take_with_series(s._s)) @overload - def __getitem__(self, item: int) -> Any: - ... + def __getitem__(self, item: int) -> Any: ... @overload def __getitem__( self, item: Series | range | slice | np.ndarray[Any, Any] | list[int], - ) -> Series: - ... + ) -> Series: ... def __getitem__( self, @@ -1430,7 +1421,7 @@ def __array_ufunc__( args.append(arg) elif isinstance(arg, Series): validity_mask &= arg.is_not_null() - args.append(arg._view(ignore_nulls=True)) + args.append(arg.to_physical()._s.to_numpy_view()) else: msg = f"unsupported type {type(arg).__name__!r} for {arg!r}" raise TypeError(msg) @@ -1487,7 +1478,7 @@ def __array_ufunc__( def _repr_html_(self) -> str: """Format output data in HTML for display in Jupyter Notebooks.""" - return self.to_frame()._repr_html_(from_series=True) + return self.to_frame()._repr_html_(_from_series=True) @deprecate_renamed_parameter("row", "index", version="0.19.3") def item(self, index: int | None = None) -> Any: @@ -1605,12 +1596,10 @@ def cbrt(self) -> Series: """ @overload - def any(self, *, ignore_nulls: Literal[True] = ...) -> bool: - ... + def any(self, *, ignore_nulls: Literal[True] = ...) -> bool: ... @overload - def any(self, *, ignore_nulls: bool) -> bool | None: - ... + def any(self, *, ignore_nulls: bool) -> bool | None: ... @deprecate_renamed_parameter("drop_nulls", "ignore_nulls", version="0.19.0") def any(self, *, ignore_nulls: bool = True) -> bool | None: @@ -1650,12 +1639,10 @@ def any(self, *, ignore_nulls: bool = True) -> bool | None: return self._s.any(ignore_nulls=ignore_nulls) @overload - def all(self, *, ignore_nulls: Literal[True] = ...) -> bool: - ... + def all(self, *, ignore_nulls: Literal[True] = ...) -> bool: ... @overload - def all(self, *, ignore_nulls: bool) -> bool | None: - ... + def all(self, *, ignore_nulls: bool) -> bool | None: ... @deprecate_renamed_parameter("drop_nulls", "ignore_nulls", version="0.19.0") def all(self, *, ignore_nulls: bool = True) -> bool | None: @@ -2133,7 +2120,9 @@ def quantile( """ return self._s.quantile(quantile, interpolation) - def to_dummies(self, separator: str = "_") -> DataFrame: + def to_dummies( + self, *, separator: str = "_", drop_first: bool = False + ) -> DataFrame: """ Get dummy/indicator variables. @@ -2141,6 +2130,8 @@ def to_dummies(self, separator: str = "_") -> DataFrame: ---------- separator Separator/delimiter used when generating column names. + drop_first + Remove the first category from the variable being encoded. Examples -------- @@ -2156,8 +2147,20 @@ def to_dummies(self, separator: str = "_") -> DataFrame: │ 0 ┆ 1 ┆ 0 │ │ 0 ┆ 0 ┆ 1 │ └─────┴─────┴─────┘ - """ - return wrap_df(self._s.to_dummies(separator)) + + >>> s.to_dummies(drop_first=True) + shape: (3, 2) + ┌─────┬─────┐ + │ a_2 ┆ a_3 │ + │ --- ┆ --- │ + │ u8 ┆ u8 │ + ╞═════╪═════╡ + │ 0 ┆ 0 │ + │ 1 ┆ 0 │ + │ 0 ┆ 1 │ + └─────┴─────┘ + """ + return wrap_df(self._s.to_dummies(separator, drop_first)) @overload def cut( @@ -2170,8 +2173,7 @@ def cut( left_closed: bool = ..., include_breaks: bool = ..., as_series: Literal[True] = ..., - ) -> Series: - ... + ) -> Series: ... @overload def cut( @@ -2184,8 +2186,7 @@ def cut( left_closed: bool = ..., include_breaks: bool = ..., as_series: Literal[False], - ) -> DataFrame: - ... + ) -> DataFrame: ... @overload def cut( @@ -2198,8 +2199,7 @@ def cut( left_closed: bool = ..., include_breaks: bool = ..., as_series: bool, - ) -> Series | DataFrame: - ... + ) -> Series | DataFrame: ... @deprecate_nonkeyword_arguments(["self", "breaks"], version="0.19.0") @deprecate_renamed_parameter("series", "as_series", version="0.19.0") @@ -2368,8 +2368,7 @@ def qcut( break_point_label: str = ..., category_label: str = ..., as_series: Literal[True] = ..., - ) -> Series: - ... + ) -> Series: ... @overload def qcut( @@ -2383,8 +2382,7 @@ def qcut( break_point_label: str = ..., category_label: str = ..., as_series: Literal[False], - ) -> DataFrame: - ... + ) -> DataFrame: ... @overload def qcut( @@ -2398,8 +2396,7 @@ def qcut( break_point_label: str = ..., category_label: str = ..., as_series: bool, - ) -> Series | DataFrame: - ... + ) -> Series | DataFrame: ... @unstable() def qcut( @@ -3547,16 +3544,16 @@ def arg_max(self) -> int | None: return self._s.arg_max() @overload - def search_sorted(self, element: int | float, side: SearchSortedSide = ...) -> int: - ... + def search_sorted( + self, element: int | float, side: SearchSortedSide = ... + ) -> int: ... @overload def search_sorted( self, element: Series | np.ndarray[Any, Any] | list[int] | list[float], side: SearchSortedSide = ..., - ) -> Series: - ... + ) -> Series: ... def search_sorted( self, @@ -4279,9 +4276,10 @@ def is_between( def to_numpy( self, *, - zero_copy_only: bool = False, + allow_copy: bool = True, writable: bool = False, use_pyarrow: bool = True, + zero_copy_only: bool | None = None, ) -> np.ndarray[Any, Any]: """ Convert this Series to a NumPy ndarray. @@ -4292,14 +4290,13 @@ def to_numpy( - Floating point `nan` values can be zero-copied - Booleans cannot be zero-copied - To ensure that no data is copied, set `zero_copy_only=True`. + To ensure that no data is copied, set `allow_copy=False`. Parameters ---------- - zero_copy_only - Raise an exception if the conversion to a NumPy would require copying - the underlying data. Data copy occurs, for example, when the Series contains - nulls or non-numeric types. + allow_copy + Allow memory to be copied to perform the conversion. If set to `False`, + causes conversions that are not zero-copy to fail. writable Ensure the resulting array is writable. This will force a copy of the data if the array was created without copy, as the underlying Arrow data is @@ -4308,6 +4305,14 @@ def to_numpy( Use `pyarrow.Array.to_numpy `_ for the conversion to NumPy. + zero_copy_only + Raise an exception if the conversion to a NumPy would require copying + the underlying data. Data copy occurs, for example, when the Series contains + nulls or non-numeric types. + + .. deprecated:: 0.20.10 + Use the `allow_copy` parameter instead, which is the inverse of this + one. Examples -------- @@ -4318,9 +4323,16 @@ def to_numpy( >>> type(arr) """ + if zero_copy_only is not None: + issue_deprecation_warning( + "The `zero_copy_only` parameter for `Series.to_numpy` is deprecated." + " Use the `allow_copy` parameter instead, which is the inverse of `zero_copy_only`.", + version="0.20.10", + ) + allow_copy = not zero_copy_only - def raise_no_zero_copy() -> None: - if zero_copy_only and not self.is_empty(): + def raise_on_copy() -> None: + if not allow_copy and not self.is_empty(): msg = "cannot return a zero-copy array" raise ValueError(msg) @@ -4336,14 +4348,14 @@ def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: raise TypeError(msg) if self.n_chunks() > 1: - raise_no_zero_copy() + raise_on_copy() self = self.rechunk() dtype = self.dtype if dtype == Array: np_array = self.explode().to_numpy( - zero_copy_only=zero_copy_only, + allow_copy=allow_copy, writable=writable, use_pyarrow=use_pyarrow, ) @@ -4356,72 +4368,42 @@ def temporal_dtype_to_numpy(dtype: PolarsDataType) -> Any: and dtype not in (Object, Datetime, Duration, Date) ): return self.to_arrow().to_numpy( - zero_copy_only=zero_copy_only, writable=writable + zero_copy_only=not allow_copy, writable=writable ) if self.null_count() == 0: if dtype.is_integer() or dtype.is_float(): - np_array = self._view(ignore_nulls=True) + np_array = self._s.to_numpy_view() elif dtype == Boolean: - raise_no_zero_copy() - np_array = self.cast(UInt8)._view(ignore_nulls=True).view(bool) + raise_on_copy() + s_u8 = self.cast(UInt8) + np_array = s_u8._s.to_numpy_view().view(bool) elif dtype in (Datetime, Duration): np_dtype = temporal_dtype_to_numpy(dtype) - np_array = self._view(ignore_nulls=True).view(np_dtype) + s_i64 = self.to_physical() + np_array = s_i64._s.to_numpy_view().view(np_dtype) elif dtype == Date: - raise_no_zero_copy() + raise_on_copy() np_dtype = temporal_dtype_to_numpy(dtype) - np_array = self.to_physical()._view(ignore_nulls=True).astype(np_dtype) + s_i32 = self.to_physical() + np_array = s_i32._s.to_numpy_view().astype(np_dtype) else: - raise_no_zero_copy() + raise_on_copy() np_array = self._s.to_numpy() else: - raise_no_zero_copy() + raise_on_copy() np_array = self._s.to_numpy() if dtype in (Datetime, Duration, Date): np_dtype = temporal_dtype_to_numpy(dtype) np_array = np_array.view(np_dtype) if writable and not np_array.flags.writeable: - raise_no_zero_copy() + raise_on_copy() np_array = np_array.copy() return np_array - def _view(self, *, ignore_nulls: bool = False) -> SeriesView: - """ - Get a view into this Series data with a numpy array. - - This operation doesn't clone data, but does not include missing values. - - Returns - ------- - SeriesView - - Parameters - ---------- - ignore_nulls - If True then nulls are converted to 0. - If False then an Exception is raised if nulls are present. - - Examples - -------- - >>> s = pl.Series("a", [1, None]) - >>> s._view(ignore_nulls=True) - SeriesView([1, 0]) - """ - if not ignore_nulls: - assert not self.null_count() - - from polars.series._numpy import SeriesView, _ptr_to_numpy - - ptr_type = dtype_to_ctype(self.dtype) - ptr = self._s.as_single_ptr() - array = _ptr_to_numpy(ptr, self.len(), ptr_type) - array.setflags(write=False) - return SeriesView(array, self) - def to_arrow(self) -> pa.Array: """ Return the underlying Arrow array. @@ -5302,7 +5284,7 @@ def map_elements( ------- Series """ - from polars.utils.udfs import warn_on_inefficient_map + from polars._utils.udfs import warn_on_inefficient_map if return_dtype is None: pl_return_dtype = None @@ -6475,6 +6457,16 @@ def kurtosis(self, *, fisher: bool = True, bias: bool = True) -> float | None: Pearson's definition is used (normal ==> 3.0). bias : bool, optional If False, the calculations are corrected for statistical bias. + + Examples + -------- + >>> s = pl.Series("grades", [66, 79, 54, 97, 96, 70, 69, 85, 93, 75]) + >>> s.kurtosis() + -1.0522623626787952 + >>> s.kurtosis(fisher=False) + 1.9477376373212048 + >>> s.kurtosis(fisher=False, bias=False) + 2.104036180264273 """ return self._s.kurtosis(fisher, bias) @@ -6604,7 +6596,7 @@ def replace( old Value or sequence of values to replace. Also accepts a mapping of values to their replacement as syntactic sugar for - `replace(new=Series(mapping.keys()), old=Series(mapping.values()))`. + `replace(old=Series(mapping.keys()), new=Series(mapping.values()))`. new Value or sequence of values to replace by. Length must match the length of `old` or have length 1. @@ -6783,7 +6775,7 @@ def ewm_mean( *, adjust: bool = True, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Series: r""" Exponentially-weighted moving average. @@ -6812,7 +6804,7 @@ def ewm_mean( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -6826,7 +6818,7 @@ def ewm_mean( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -6834,7 +6826,7 @@ def ewm_mean( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -6844,7 +6836,7 @@ def ewm_mean( Examples -------- >>> s = pl.Series([1, 2, 3]) - >>> s.ewm_mean(com=1) + >>> s.ewm_mean(com=1, ignore_nulls=False) shape: (3,) Series: '' [f64] [ @@ -6865,7 +6857,7 @@ def ewm_std( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Series: r""" Exponentially-weighted moving standard deviation. @@ -6894,7 +6886,7 @@ def ewm_std( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -6911,7 +6903,7 @@ def ewm_std( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -6919,7 +6911,7 @@ def ewm_std( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -6929,7 +6921,7 @@ def ewm_std( Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s.ewm_std(com=1) + >>> s.ewm_std(com=1, ignore_nulls=False) shape: (3,) Series: 'a' [f64] [ @@ -6950,7 +6942,7 @@ def ewm_var( adjust: bool = True, bias: bool = False, min_periods: int = 1, - ignore_nulls: bool = True, + ignore_nulls: bool | None = None, ) -> Series: r""" Exponentially-weighted moving variance. @@ -6979,7 +6971,7 @@ def ewm_var( Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings - - When `adjust=True` the EW function is calculated + - When `adjust=True` (the default) the EW function is calculated using weights :math:`w_i = (1 - \alpha)^i` - When `adjust=False` the EW function is calculated recursively by @@ -6996,7 +6988,7 @@ def ewm_var( ignore_nulls Ignore missing values when calculating weights. - - When `ignore_nulls=False` (default), weights are based on absolute + - When `ignore_nulls=False`, weights are based on absolute positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of @@ -7004,7 +6996,7 @@ def ewm_var( :math:`(1-\alpha)^2` and :math:`1` if `adjust=True`, and :math:`(1-\alpha)^2` and :math:`\alpha` if `adjust=False`. - - When `ignore_nulls=True`, weights are based + - When `ignore_nulls=True` (current default), weights are based on relative positions. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating the final weighted average of [:math:`x_0`, None, :math:`x_2`] are @@ -7014,7 +7006,7 @@ def ewm_var( Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s.ewm_var(com=1) + >>> s.ewm_var(com=1, ignore_nulls=False) shape: (3,) Series: 'a' [f64] [ @@ -7520,7 +7512,7 @@ def cumprod(self, *, reverse: bool = False) -> Series: return self.cum_prod(reverse=reverse) @deprecate_function( - "Use `Series.to_numpy(zero_copy_only=True) instead.", version="0.19.14" + "Use `Series.to_numpy(allow_copy=False) instead.", version="0.19.14" ) def view(self, *, ignore_nulls: bool = False) -> SeriesView: """ @@ -7538,7 +7530,16 @@ def view(self, *, ignore_nulls: bool = False) -> SeriesView: If True then nulls are converted to 0. If False then an Exception is raised if nulls are present. """ - return self._view(ignore_nulls=ignore_nulls) + if not ignore_nulls: + assert not self.null_count() + + from polars.series._numpy import SeriesView, _ptr_to_numpy + + ptr_type = dtype_to_ctype(self.dtype) + ptr = self._s.as_single_ptr() + array = _ptr_to_numpy(ptr, self.len(), ptr_type) + array.setflags(write=False) + return SeriesView(array, self) @deprecate_function( "It has been renamed to `replace`." diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 9b70d177a350f..8aad10f4dcf42 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING -from polars.datatypes.constants import N_INFER_DEFAULT -from polars.series.utils import expr_dispatch -from polars.utils.deprecation import ( +from polars._utils.deprecation import ( deprecate_renamed_function, deprecate_renamed_parameter, ) +from polars.datatypes.constants import N_INFER_DEFAULT +from polars.series.utils import expr_dispatch if TYPE_CHECKING: from polars import Expr, Series @@ -140,6 +140,7 @@ def to_datetime( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime + - `'null'`: set to null Examples -------- @@ -237,6 +238,7 @@ def strptime( - `'raise'` (default): raise - `'earliest'`: use the earliest datetime - `'latest'`: use the latest datetime + - `'null'`: set to null Notes ----- @@ -635,8 +637,8 @@ def starts_with(self, prefix: str | Expr) -> Series: """ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: - """ - Decode a value using the provided encoding. + r""" + Decode values using the provided encoding. Parameters ---------- @@ -645,6 +647,23 @@ def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Series: strict Raise an error if the underlying value cannot be decoded, otherwise mask out with a null value. + + Returns + ------- + Series + Series of data type :class:`Binary`. + + Examples + -------- + >>> s = pl.Series("color", ["000000", "ffff00", "0000ff"]) + >>> s.str.decode("hex") + shape: (3,) + Series: 'color' [binary] + [ + b"\x00\x00\x00" + b"\xff\xff\x00" + b"\x00\x00\xff" + ] """ def encode(self, encoding: TransferEncoding) -> Series: diff --git a/py-polars/polars/series/struct.py b/py-polars/polars/series/struct.py index b0fe9f4e22b9f..ceb843078c153 100644 --- a/py-polars/polars/series/struct.py +++ b/py-polars/polars/series/struct.py @@ -3,9 +3,9 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Sequence +from polars._utils.various import BUILDING_SPHINX_DOCS, sphinx_accessor +from polars._utils.wrap import wrap_df from polars.series.utils import expr_dispatch -from polars.utils._wrap import wrap_df -from polars.utils.various import BUILDING_SPHINX_DOCS, sphinx_accessor if TYPE_CHECKING: from polars import DataFrame, DataType, Series @@ -37,7 +37,15 @@ def _ipython_key_completions_(self) -> list[str]: @property def fields(self) -> list[str]: - """Get the names of the fields.""" + """ + Get the names of the fields. + + Examples + -------- + >>> s = pl.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + >>> s.struct.fields + ['a', 'b'] + """ if getattr(self, "_s", None) is None: return [] return self._s.struct_fields() @@ -49,7 +57,18 @@ def field(self, name: str) -> Series: Parameters ---------- name - Name of the field + Name of the field. + + Examples + -------- + >>> s = pl.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + >>> s.struct.field("a") + shape: (2,) + Series: 'a' [i64] + [ + 1 + 3 + ] """ def rename_fields(self, names: Sequence[str]) -> Series: @@ -59,12 +78,29 @@ def rename_fields(self, names: Sequence[str]) -> Series: Parameters ---------- names - New names in the order of the struct's fields + New names in the order of the struct's fields. + + Examples + -------- + >>> s = pl.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + >>> s.struct.fields + ['a', 'b'] + >>> s = s.struct.rename_fields(["c", "d"]) + >>> s.struct.fields + ['c', 'd'] """ @property def schema(self) -> OrderedDict[str, DataType]: - """Get the struct definition as a name/dtype schema dict.""" + """ + Get the struct definition as a name/dtype schema dict. + + Examples + -------- + >>> s = pl.Series([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) + >>> s.struct.schema + OrderedDict({'a': Int64, 'b': Int64}) + """ if getattr(self, "_s", None) is None: return OrderedDict() return OrderedDict(self._s.dtype().to_schema()) diff --git a/py-polars/polars/series/utils.py b/py-polars/polars/series/utils.py index fb2f1440fb7a2..237b55a396dac 100644 --- a/py-polars/polars/series/utils.py +++ b/py-polars/polars/series/utils.py @@ -7,8 +7,8 @@ import polars._reexport as pl from polars import functions as F +from polars._utils.wrap import wrap_s from polars.datatypes import dtype_to_ffiname -from polars.utils._wrap import wrap_s if TYPE_CHECKING: from polars import Series diff --git a/py-polars/polars/sql/context.py b/py-polars/polars/sql/context.py index afbd1dcea4c1a..db02e51a1dd82 100644 --- a/py-polars/polars/sql/context.py +++ b/py-polars/polars/sql/context.py @@ -3,12 +3,12 @@ import contextlib from typing import TYPE_CHECKING, Collection, Generic, Mapping, overload +from polars._utils.unstable import issue_unstable_warning +from polars._utils.various import _get_stack_locals +from polars._utils.wrap import wrap_ldf from polars.dataframe import DataFrame from polars.lazyframe import LazyFrame from polars.type_aliases import FrameType -from polars.utils._wrap import wrap_ldf -from polars.utils.unstable import issue_unstable_warning -from polars.utils.various import _get_stack_locals with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PySQLContext @@ -52,8 +52,7 @@ def __init__( register_globals: bool | int = ..., eager_execution: Literal[False] = False, **named_frames: DataFrame | LazyFrame | None, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -63,8 +62,7 @@ def __init__( register_globals: bool | int = ..., eager_execution: Literal[True], **named_frames: DataFrame | LazyFrame | None, - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -162,38 +160,32 @@ def __repr__(self) -> str: @overload def execute( self: SQLContext[DataFrame], query: str, eager: None = ... - ) -> DataFrame: - ... + ) -> DataFrame: ... @overload def execute( self: SQLContext[DataFrame], query: str, eager: Literal[False] - ) -> LazyFrame: - ... + ) -> LazyFrame: ... @overload def execute( self: SQLContext[DataFrame], query: str, eager: Literal[True] - ) -> DataFrame: - ... + ) -> DataFrame: ... @overload def execute( self: SQLContext[LazyFrame], query: str, eager: None = ... - ) -> LazyFrame: - ... + ) -> LazyFrame: ... @overload def execute( self: SQLContext[LazyFrame], query: str, eager: Literal[False] - ) -> LazyFrame: - ... + ) -> LazyFrame: ... @overload def execute( self: SQLContext[LazyFrame], query: str, eager: Literal[True] - ) -> DataFrame: - ... + ) -> DataFrame: ... def execute(self, query: str, eager: bool | None = None) -> LazyFrame | DataFrame: """ diff --git a/py-polars/polars/string_cache.py b/py-polars/polars/string_cache.py index dbf15d6244e84..6955d7da5e8e9 100644 --- a/py-polars/polars/string_cache.py +++ b/py-polars/polars/string_cache.py @@ -3,7 +3,7 @@ import contextlib from typing import TYPE_CHECKING -from polars.utils.deprecation import issue_deprecation_warning +from polars._utils.deprecation import issue_deprecation_warning with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index 2705723965c78..bf5c30e19c884 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -461,8 +461,7 @@ def dataframes( allow_infinities: bool = True, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, -) -> SearchStrategy[DataFrame]: - ... +) -> SearchStrategy[DataFrame]: ... @overload @@ -481,8 +480,7 @@ def dataframes( allow_infinities: bool = True, allowed_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, excluded_dtypes: Collection[PolarsDataType] | PolarsDataType | None = None, -) -> SearchStrategy[LazyFrame]: - ... +) -> SearchStrategy[LazyFrame]: ... @defines_strategy() diff --git a/py-polars/polars/testing/parametric/profiles.py b/py-polars/polars/testing/parametric/profiles.py index c1afda1c2bfe1..76682af6d7f2d 100644 --- a/py-polars/polars/testing/parametric/profiles.py +++ b/py-polars/polars/testing/parametric/profiles.py @@ -5,8 +5,8 @@ from hypothesis import settings +from polars._utils.deprecation import deprecate_nonkeyword_arguments from polars.type_aliases import ParametricProfileNames -from polars.utils.deprecation import deprecate_nonkeyword_arguments @deprecate_nonkeyword_arguments(allowed_args=["profile"], version="0.19.3") diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index eee5e670d8a67..e7a02ef611974 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -23,6 +23,7 @@ import sys from sqlalchemy import Engine + from sqlalchemy.orm import Session from polars import DataFrame, Expr, LazyFrame, Series from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType @@ -147,7 +148,7 @@ ] # ListToStructWidthStrategy # The following have no equivalent on the Rust side -Ambiguous: TypeAlias = Literal["earliest", "latest", "raise"] +Ambiguous: TypeAlias = Literal["earliest", "latest", "raise", "null"] ConcatMethod = Literal[ "vertical", "vertical_relaxed", @@ -227,17 +228,11 @@ class SeriesBuffers(TypedDict): # minimal protocol definitions that can reasonably represent # an executable connection, cursor, or equivalent object class BasicConnection(Protocol): # noqa: D101 - def close(self) -> None: - """Close the connection.""" - def cursor(self, *args: Any, **kwargs: Any) -> Any: """Return a cursor object.""" class BasicCursor(Protocol): # noqa: D101 - def close(self) -> None: - """Close the cursor.""" - def execute(self, *args: Any, **kwargs: Any) -> Any: """Execute a query.""" @@ -250,4 +245,4 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: """Fetch results in batches.""" -ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine"] +ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine", "Session"] diff --git a/py-polars/polars/utils/__init__.py b/py-polars/polars/utils/__init__.py index 133bca13981b1..6efa5aaa60c0e 100644 --- a/py-polars/polars/utils/__init__.py +++ b/py-polars/polars/utils/__init__.py @@ -1,38 +1 @@ -""" -Utility functions. - -Functions that are part of the public API are re-exported here. -""" -from polars.utils._scan import _execute_from_rust -from polars.utils.convert import ( - _date_to_pl_date, - _datetime_for_any_value, - _datetime_for_any_value_windows, - _time_to_pl_time, - _timedelta_to_pl_timedelta, - _to_python_date, - _to_python_datetime, - _to_python_decimal, - _to_python_time, - _to_python_timedelta, -) -from polars.utils.various import NoDefault, _polars_warn, is_column, no_default - -__all__ = [ - "NoDefault", - "is_column", - "no_default", - # Required for Rust bindings - "_date_to_pl_date", - "_datetime_for_any_value", - "_datetime_for_any_value_windows", - "_execute_from_rust", - "_polars_warn", - "_time_to_pl_time", - "_timedelta_to_pl_timedelta", - "_to_python_date", - "_to_python_datetime", - "_to_python_decimal", - "_to_python_time", - "_to_python_timedelta", -] +"""Deprecated module. Do not use.""" diff --git a/py-polars/polars/utils/convert.py b/py-polars/polars/utils/convert.py deleted file mode 100644 index 2e963cf04a347..0000000000000 --- a/py-polars/polars/utils/convert.py +++ /dev/null @@ -1,269 +0,0 @@ -from __future__ import annotations - -import sys -from datetime import datetime, time, timedelta, timezone -from decimal import Context -from functools import lru_cache -from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar, overload - -from polars.dependencies import _ZONEINFO_AVAILABLE, zoneinfo - -if TYPE_CHECKING: - from collections.abc import Reversible - from datetime import date, tzinfo - from decimal import Decimal - - from polars.type_aliases import TimeUnit - - if sys.version_info >= (3, 10): - from typing import ParamSpec - else: - from typing_extensions import ParamSpec - - P = ParamSpec("P") - T = TypeVar("T") - - # the below shenanigans with ZoneInfo are all to handle a - # typing issue in py < 3.9 while preserving lazy-loading - if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo - elif _ZONEINFO_AVAILABLE: - from backports.zoneinfo._zoneinfo import ZoneInfo - - def get_zoneinfo(key: str) -> ZoneInfo: # noqa: D103 - pass - -else: - - @lru_cache(None) - def get_zoneinfo(key: str) -> ZoneInfo: # noqa: D103 - return zoneinfo.ZoneInfo(key) - - -# note: reversed views don't match as instances of MappingView -if sys.version_info >= (3, 11): - _views: list[Reversible[Any]] = [{}.keys(), {}.values(), {}.items()] - _reverse_mapping_views = tuple(type(reversed(view)) for view in _views) - -SECONDS_PER_DAY = 86_400 -SECONDS_PER_HOUR = 3_600 -NS_PER_SECOND = 1_000_000_000 -US_PER_SECOND = 1_000_000 -MS_PER_SECOND = 1_000 - -EPOCH = datetime(1970, 1, 1).replace(tzinfo=None) -EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc) - - -def _timestamp_in_seconds(dt: datetime) -> int: - du = dt - EPOCH_UTC - return du.days * SECONDS_PER_DAY + du.seconds - - -@overload -def _timedelta_to_pl_duration(td: None) -> None: - ... - - -@overload -def _timedelta_to_pl_duration(td: timedelta | str) -> str: - ... - - -def _timedelta_to_pl_duration(td: timedelta | str | None) -> str | None: - """Convert python timedelta to a polars duration string.""" - if td is None or isinstance(td, str): - return td - - if td.days >= 0: - d = td.days and f"{td.days}d" or "" - s = td.seconds and f"{td.seconds}s" or "" - us = td.microseconds and f"{td.microseconds}us" or "" - else: - if not td.seconds and not td.microseconds: - d = td.days and f"{td.days}d" or "" - s = "" - us = "" - else: - corrected_d = td.days + 1 - d = corrected_d and f"{corrected_d}d" or "-" - corrected_seconds = SECONDS_PER_DAY - (td.seconds + (td.microseconds > 0)) - s = corrected_seconds and f"{corrected_seconds}s" or "" - us = td.microseconds and f"{10**6 - td.microseconds}us" or "" - - return f"{d}{s}{us}" - - -def _negate_duration(duration: str) -> str: - if duration.startswith("-"): - return duration[1:] - return f"-{duration}" - - -def _time_to_pl_time(t: time) -> int: - t = t.replace(tzinfo=timezone.utc) - seconds = t.hour * SECONDS_PER_HOUR + t.minute * 60 + t.second - microseconds = t.microsecond - return seconds * NS_PER_SECOND + microseconds * 1_000 - - -def _date_to_pl_date(d: date) -> int: - dt = datetime.combine(d, datetime.min.time()).replace(tzinfo=timezone.utc) - return int(dt.timestamp()) // SECONDS_PER_DAY - - -def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int: - """Convert a python datetime to a timestamp in given time unit.""" - if dt.tzinfo is None: - # Make sure to use UTC rather than system time zone. - dt = dt.replace(tzinfo=timezone.utc) - microseconds = dt.microsecond - seconds = _timestamp_in_seconds(dt) - if time_unit == "ns": - return seconds * NS_PER_SECOND + microseconds * 1_000 - elif time_unit == "us" or time_unit is None: - return seconds * US_PER_SECOND + microseconds - elif time_unit == "ms": - return seconds * MS_PER_SECOND + microseconds // 1_000 - msg = f"`time_unit` must be one of {{'ms', 'us', 'ns'}}, got {time_unit!r}" - raise ValueError(msg) - - -def _timedelta_to_pl_timedelta(td: timedelta, time_unit: TimeUnit | None) -> int: - """Convert a Python timedelta object to a total number of subseconds.""" - microseconds = td.microseconds - seconds = td.days * SECONDS_PER_DAY + td.seconds - if time_unit == "ns": - return seconds * NS_PER_SECOND + microseconds * 1_000 - elif time_unit == "us" or time_unit is None: - return seconds * US_PER_SECOND + microseconds - elif time_unit == "ms": - return seconds * MS_PER_SECOND + microseconds // 1_000 - - -def _to_python_time(value: int) -> time: - """Convert polars int64 (ns) timestamp to python time object.""" - if value == 0: - return time(microsecond=0) - else: - seconds, nanoseconds = divmod(value, NS_PER_SECOND) - minutes, seconds = divmod(seconds, 60) - hours, minutes = divmod(minutes, 60) - return time( - hour=hours, minute=minutes, second=seconds, microsecond=nanoseconds // 1_000 - ) - - -def _to_python_timedelta( - value: int | float, time_unit: TimeUnit | None = "ns" -) -> timedelta: - if time_unit == "ns": - return timedelta(microseconds=value // 1_000) - elif time_unit == "us": - return timedelta(microseconds=value) - elif time_unit == "ms": - return timedelta(milliseconds=value) - else: - msg = f"`time_unit` must be one of {{'ns', 'us', 'ms'}}, got {time_unit!r}" - raise ValueError(msg) - - -@lru_cache(256) -def _to_python_date(value: int | float) -> date: - """Convert polars int64 timestamp to Python date.""" - return (EPOCH_UTC + timedelta(seconds=value * SECONDS_PER_DAY)).date() - - -def _to_python_datetime( - value: int | float, - time_unit: TimeUnit | None = "ns", - time_zone: str | None = None, -) -> datetime: - """Convert polars int64 timestamp to Python datetime.""" - if not time_zone: - if time_unit == "us": - return EPOCH + timedelta(microseconds=value) - elif time_unit == "ns": - return EPOCH + timedelta(microseconds=value // 1_000) - elif time_unit == "ms": - return EPOCH + timedelta(milliseconds=value) - else: - msg = f"`time_unit` must be one of {{'ns', 'us', 'ms'}}, got {time_unit!r}" - raise ValueError(msg) - elif _ZONEINFO_AVAILABLE: - if time_unit == "us": - dt = EPOCH_UTC + timedelta(microseconds=value) - elif time_unit == "ns": - dt = EPOCH_UTC + timedelta(microseconds=value // 1_000) - elif time_unit == "ms": - dt = EPOCH_UTC + timedelta(milliseconds=value) - else: - msg = f"`time_unit` must be one of {{'ns', 'us', 'ms'}}, got {time_unit!r}" - raise ValueError(msg) - return _localize(dt, time_zone) - else: - msg = "install polars[timezone] to handle datetimes with time zone information" - raise ImportError(msg) - - -def _localize(dt: datetime, time_zone: str) -> datetime: - # zone info installation should already be checked - _tzinfo: ZoneInfo | tzinfo - try: - _tzinfo = get_zoneinfo(time_zone) - except zoneinfo.ZoneInfoNotFoundError: - # try fixed offset, which is not supported by ZoneInfo - _tzinfo = _parse_fixed_tz_offset(time_zone) - - return dt.astimezone(_tzinfo) - - -def _datetime_for_any_value(dt: datetime) -> tuple[int, int]: - """Used in PyO3 AnyValue conversion.""" - # returns (s, ms) - if dt.tzinfo is None: - return ( - _timestamp_in_seconds(dt.replace(tzinfo=timezone.utc)), - dt.microsecond, - ) - return (_timestamp_in_seconds(dt), dt.microsecond) - - -def _datetime_for_any_value_windows(dt: datetime) -> tuple[float, int]: - """Used in PyO3 AnyValue conversion.""" - if dt.tzinfo is None: - dt = _localize(dt, "UTC") - # returns (s, ms) - return (_timestamp_in_seconds(dt), dt.microsecond) - - -# cache here as we have a single tz per column -# and this function will be called on every conversion -@lru_cache(16) -def _parse_fixed_tz_offset(offset: str) -> tzinfo: - try: - # use fromisoformat to parse the offset - dt_offset = datetime.fromisoformat("2000-01-01T00:00:00" + offset) - - # alternatively, we parse the offset ourselves extracting hours and - # minutes, then we can construct: - # tzinfo=timezone(timedelta(hours=..., minutes=...)) - except ValueError: - msg = f"offset: {offset!r} not understood" - raise ValueError(msg) from None - - return dt_offset.tzinfo # type: ignore[return-value] - - -def _to_python_decimal( - sign: int, digits: Sequence[int], prec: int, scale: int -) -> Decimal: - return _create_decimal_with_prec(prec)((sign, digits, scale)) - - -@lru_cache(None) -def _create_decimal_with_prec( - precision: int, -) -> Callable[[tuple[int, Sequence[int], int]], Decimal]: - # pre-cache contexts so we don't have to spend time on recreating them every time - return Context(prec=precision).create_decimal diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 6c98f053c4525..f7ee74e0eece4 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -1,921 +1,43 @@ -"""Utilities related to user defined functions (such as those passed to `apply`).""" -from __future__ import annotations - -import datetime -import dis -import inspect -import re -import sys -import warnings -from bisect import bisect_left -from collections import defaultdict -from dis import get_instructions -from inspect import signature -from itertools import count, zip_longest -from pathlib import Path -from typing import ( - TYPE_CHECKING, - AbstractSet, - Any, - Callable, - ClassVar, - Iterator, - Literal, - NamedTuple, - Union, -) - -if TYPE_CHECKING: - from dis import Instruction - - if sys.version_info >= (3, 10): - from typing import TypeAlias - else: - from typing_extensions import TypeAlias +"""Deprecated module. Do not use.""" +import os +from typing import Any -class StackValue(NamedTuple): - operator: str - operator_arity: int - left_operand: str - right_operand: str +from polars._utils.deprecation import deprecate_function +__all__ = ["_get_shared_lib_location"] -MapTarget: TypeAlias = Literal["expr", "frame", "series"] -StackEntry: TypeAlias = Union[str, StackValue] -_MIN_PY311 = sys.version_info >= (3, 11) -_MIN_PY312 = _MIN_PY311 and sys.version_info >= (3, 12) - - -class OpNames: - BINARY: ClassVar[dict[str, str]] = { - "BINARY_ADD": "+", - "BINARY_AND": "&", - "BINARY_FLOOR_DIVIDE": "//", - "BINARY_LSHIFT": "<<", - "BINARY_RSHIFT": ">>", - "BINARY_MODULO": "%", - "BINARY_MULTIPLY": "*", - "BINARY_OR": "|", - "BINARY_POWER": "**", - "BINARY_SUBTRACT": "-", - "BINARY_TRUE_DIVIDE": "/", - "BINARY_XOR": "^", - } - CALL = frozenset({"CALL"} if _MIN_PY311 else {"CALL_FUNCTION", "CALL_METHOD"}) - CONTROL_FLOW: ClassVar[dict[str, str]] = ( - { - "POP_JUMP_FORWARD_IF_FALSE": "&", - "POP_JUMP_FORWARD_IF_TRUE": "|", - "JUMP_IF_FALSE_OR_POP": "&", - "JUMP_IF_TRUE_OR_POP": "|", - } - # note: 3.12 dropped POP_JUMP_FORWARD_IF_* opcodes - if _MIN_PY311 and not _MIN_PY312 - else { - "POP_JUMP_IF_FALSE": "&", - "POP_JUMP_IF_TRUE": "|", - "JUMP_IF_FALSE_OR_POP": "&", - "JUMP_IF_TRUE_OR_POP": "|", - } - ) - LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL")) - LOAD_ATTR = frozenset({"LOAD_METHOD", "LOAD_ATTR"}) - LOAD = LOAD_VALUES | LOAD_ATTR - SYNTHETIC: ClassVar[dict[str, int]] = { - "POLARS_EXPRESSION": 1, - } - UNARY: ClassVar[dict[str, str]] = { - "UNARY_NEGATIVE": "-", - "UNARY_POSITIVE": "+", - "UNARY_NOT": "~", - } - PARSEABLE_OPS = frozenset( - {"BINARY_OP", "BINARY_SUBSCR", "COMPARE_OP", "CONTAINS_OP", "IS_OP"} - | set(UNARY) - | set(CONTROL_FLOW) - | set(SYNTHETIC) - | LOAD_VALUES - ) - UNARY_VALUES = frozenset(UNARY.values()) - - -# numpy functions that we can map to native expressions -_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy")) -_NUMPY_FUNCTIONS = frozenset( - ( - # "abs", # TODO: this one clashes with Python builtin abs - "arccos", - "arccosh", - "arcsin", - "arcsinh", - "arctan", - "arctanh", - "cbrt", - "ceil", - "cos", - "cosh", - "degrees", - "exp", - "floor", - "log", - "log10", - "log1p", - "radians", - "sign", - "sin", - "sinh", - "sqrt", - "tan", - "tanh", - ) +@deprecate_function( + "It will be removed in the next breaking release." + " The new `register_plugin_function` function has this functionality built in." + " Use `from polars.plugins import register_plugin_function` to import that function." + " Check the user guide for the currently-recommended way to register a plugin:" + " https://docs.pola.rs/user-guide/expressions/plugins", + version="0.20.16", ) - -# python functions that we can map to native expressions -_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"} -_PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"} -_PYTHON_METHODS_MAP = { - "lower": "str.to_lowercase", - "title": "str.to_titlecase", - "upper": "str.to_uppercase", -} - -_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [ - # lambda x: module.func(CONSTANT) - { - "argument_1_opname": [{"LOAD_CONST"}], - "argument_2_opname": [], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [_NUMPY_MODULE_ALIASES], - "attribute_name": [], - "function_name": [_NUMPY_FUNCTIONS], - }, - # lambda x: module.func(x) - { - "argument_1_opname": [{"LOAD_FAST"}], - "argument_2_opname": [], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [_NUMPY_MODULE_ALIASES], - "attribute_name": [], - "function_name": [_NUMPY_FUNCTIONS], - }, - { - "argument_1_opname": [{"LOAD_FAST"}], - "argument_2_opname": [], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [{"json"}], - "attribute_name": [], - "function_name": [{"loads"}], - }, - # lambda x: module.func(x, CONSTANT) - { - "argument_1_opname": [{"LOAD_FAST"}], - "argument_2_opname": [{"LOAD_CONST"}], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [{"datetime"}], - "attribute_name": [], - "function_name": [{"strptime"}], - }, - # lambda x: module.attribute.func(x, CONSTANT) - { - "argument_1_opname": [{"LOAD_FAST"}], - "argument_2_opname": [{"LOAD_CONST"}], - "module_opname": [{"LOAD_ATTR"}], - "attribute_opname": [OpNames.LOAD_ATTR], - "module_name": [{"datetime", "dt"}], - "attribute_name": [{"datetime"}], - "function_name": [{"strptime"}], - }, -] -# In addition to `lambda x: func(x)`, also support cases when a unary operation -# has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`. -_FUNCTION_KINDS = [ - # Dict entry 1 has incompatible type "str": "object"; - # expected "str": "list[AbstractSet[str]]" - {**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item] - for kind in _FUNCTION_KINDS - for unary in [[set(OpNames.UNARY)], []] -] - - -def _get_all_caller_variables() -> dict[str, Any]: - """Get all local and global variables from caller's frame.""" - pkg_dir = Path(__file__).parent.parent - - # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow - frame = inspect.currentframe() - n = 0 - try: - while frame: - fname = inspect.getfile(frame) - if fname.startswith(str(pkg_dir)): - frame = frame.f_back - n += 1 - else: - break - variables: dict[str, Any] - if frame is None: - variables = {} - else: - variables = {**frame.f_locals, **frame.f_globals} - finally: - # https://docs.python.org/3/library/inspect.html - # > Though the cycle detector will catch these, destruction of the frames - # > (and local variables) can be made deterministic by removing the cycle - # > in a finally clause. - del frame - return variables - - -class BytecodeParser: - """Introspect UDF bytecode and determine if we can rewrite as native expression.""" - - _map_target_name: str | None = None - - def __init__(self, function: Callable[[Any], Any], map_target: MapTarget): - try: - original_instructions = get_instructions(function) - except TypeError: - # in case we hit something that can't be disassembled (eg: code object - # unavailable, like a bare numpy ufunc that isn't in a lambda/function) - original_instructions = iter([]) - - self._function = function - self._map_target = map_target - self._param_name = self._get_param_name(function) - self._rewritten_instructions = RewrittenInstructions( - instructions=original_instructions, - ) - - @staticmethod - def _get_param_name(function: Callable[[Any], Any]) -> str | None: - """Return single function parameter name.""" - try: - # note: we do not parse/handle functions with > 1 params - sig = signature(function) - except ValueError: - return None - return ( - next(iter(parameters.keys())) - if len(parameters := sig.parameters) == 1 - else None - ) - - def _inject_nesting( - self, - expression_blocks: dict[int, str], - logical_instructions: list[Instruction], - ) -> list[tuple[int, str]]: - """Inject nesting boundaries into expression blocks (as parentheses).""" - if logical_instructions: - # reconstruct nesting boundaries for mixed and/or ops by associating control - # flow jump offsets with their target expression blocks and applying parens - if len({inst.opname for inst in logical_instructions}) > 1: - block_offsets: list[int] = list(expression_blocks.keys()) - prev_end = -1 - for inst in logical_instructions: - start = block_offsets[bisect_left(block_offsets, inst.offset) - 1] - end = block_offsets[bisect_left(block_offsets, inst.argval) - 1] - if not (start == 0 and end == block_offsets[-1]): - if prev_end not in (start, end): - expression_blocks[start] = "(" + expression_blocks[start] - expression_blocks[end] += ")" - prev_end = end - - for inst in logical_instructions: # inject connecting "&" and "|" ops - expression_blocks[inst.offset] = OpNames.CONTROL_FLOW[inst.opname] - - return sorted(expression_blocks.items()) - - def _get_target_name(self, col: str, expression: str) -> str: - """The name of the object against which the 'map' is being invoked.""" - if self._map_target_name is not None: - return self._map_target_name - else: - col_expr = f'pl.col("{col}")' - if self._map_target == "expr": - return col_expr - elif self._map_target == "series": - # note: handle overlapping name from global variables; fallback - # through "s", "srs", "series" and (finally) srs0 -> srsN... - search_expr = expression.replace(col_expr, "") - for name in ("s", "srs", "series"): - if not re.search(rf"\b{name}\b", search_expr): - self._map_target_name = name - return name - n = count() - while True: - name = f"srs{next(n)}" - if not re.search(rf"\b{name}\b", search_expr): - self._map_target_name = name - return name - - msg = f"TODO: map_target = {self._map_target!r}" - raise NotImplementedError(msg) - - @property - def map_target(self) -> MapTarget: - """The map target, eg: one of 'expr', 'frame', or 'series'.""" - return self._map_target - - def can_attempt_rewrite(self) -> bool: - """ - Determine if we may be able to offer a native polars expression instead. - - Note that `lambda x: x` is inefficient, but we ignore it because it is not - guaranteed that using the equivalent bare constant value will return the - same output. (Hopefully nobody is writing lambdas like that anyway...) - """ - return ( - self._param_name is not None - # check minimum number of ops, ensuring all are parseable - and len(self._rewritten_instructions) >= 2 - and all( - inst.opname in OpNames.PARSEABLE_OPS - for inst in self._rewritten_instructions - ) - # exclude constructs/functions with multiple RETURN_VALUE ops - and sum( - 1 - for inst in self.original_instructions - if inst.opname == "RETURN_VALUE" - ) - == 1 - ) - - def dis(self) -> None: - """Print disassembled function bytecode.""" - dis.dis(self._function) - - @property - def function(self) -> Callable[[Any], Any]: - """The function being parsed.""" - return self._function - - @property - def original_instructions(self) -> list[Instruction]: - """The original bytecode instructions from the function we are parsing.""" - return list(self._rewritten_instructions._original_instructions) - - @property - def param_name(self) -> str | None: - """The parameter name of the function being parsed.""" - return self._param_name - - @property - def rewritten_instructions(self) -> list[Instruction]: - """The rewritten bytecode instructions from the function we are parsing.""" - return list(self._rewritten_instructions) - - def to_expression(self, col: str) -> str | None: - """Translate postfix bytecode instructions to polars expression/string.""" - self._map_target_name = None - if self._param_name is None: - return None - - # decompose bytecode into logical 'and'/'or' expression blocks (if present) - control_flow_blocks = defaultdict(list) - logical_instructions = [] - jump_offset = 0 - for idx, inst in enumerate(self._rewritten_instructions): - if inst.opname in OpNames.CONTROL_FLOW: - jump_offset = self._rewritten_instructions[idx + 1].offset - logical_instructions.append(inst) - else: - control_flow_blocks[jump_offset].append(inst) - - # convert each block to a polars expression string - caller_variables: dict[str, Any] = {} - try: - expression_strings = self._inject_nesting( - { - offset: InstructionTranslator( - instructions=ops, - caller_variables=caller_variables, - map_target=self._map_target, - ).to_expression( - col=col, - param_name=self._param_name, - depth=int(bool(logical_instructions)), - ) - for offset, ops in control_flow_blocks.items() - }, - logical_instructions, - ) - except NotImplementedError: - return None - polars_expr = " ".join(expr for _offset, expr in expression_strings) - - # note: if no 'pl.col' in the expression, it likely represents a compound - # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn - if "pl.col(" not in polars_expr: - return None - elif self._map_target == "series": - target_name = self._get_target_name(col, polars_expr) - return polars_expr.replace(f'pl.col("{col}")', target_name) - else: - return polars_expr - - def warn( - self, - col: str, - suggestion_override: str | None = None, - udf_override: str | None = None, - ) -> None: - """Generate warning that suggests an equivalent native polars expression.""" - # Import these here so that udfs can be imported without polars installed. - - from polars.exceptions import PolarsInefficientMapWarning - from polars.utils.various import ( - find_stacklevel, - in_terminal_that_supports_colour, - ) - - suggested_expression = suggestion_override or self.to_expression(col) - - if suggested_expression is not None: - target_name = self._get_target_name(col, suggested_expression) - func_name = udf_override or self._function.__name__ or "..." - if func_name == "": - func_name = f"lambda {self._param_name}: ..." - - addendum = ( - 'Note: in list.eval context, pl.col("") should be written as pl.element()' - if 'pl.col("")' in suggested_expression - else "" - ) - if self._map_target == "expr": - apitype = "expressions" - clsname = "Expr" - else: - apitype = "series" - clsname = "Series" - - before, after = ( - ( - f" \033[31m- {target_name}.map_elements({func_name})\033[0m\n", - f" \033[32m+ {suggested_expression}\033[0m\n{addendum}", - ) - if in_terminal_that_supports_colour() - else ( - f" - {target_name}.map_elements({func_name})\n", - f" + {suggested_expression}\n{addendum}", - ) - ) - warnings.warn( - f"\n{clsname}.map_elements is significantly slower than the native {apitype} API.\n" - "Only use if you absolutely CANNOT implement your logic otherwise.\n" - "Replace this expression...\n" - f"{before}" - "with this one instead:\n" - f"{after}", - PolarsInefficientMapWarning, - stacklevel=find_stacklevel(), - ) - - -class InstructionTranslator: - """Translates Instruction bytecode to a polars expression string.""" - - def __init__( - self, - instructions: list[Instruction], - caller_variables: dict[str, Any], - map_target: MapTarget, - ) -> None: - self._caller_variables: dict[str, Any] = caller_variables - self._stack = self._to_intermediate_stack(instructions, map_target) - - def to_expression(self, col: str, param_name: str, depth: int) -> str: - """Convert intermediate stack to polars expression string.""" - return self._expr(self._stack, col, param_name, depth) - - @staticmethod - def op(inst: Instruction) -> str: - """Convert bytecode instruction to suitable intermediate op string.""" - if inst.opname in OpNames.CONTROL_FLOW: - return OpNames.CONTROL_FLOW[inst.opname] - elif inst.argrepr: - return inst.argrepr - elif inst.opname == "IS_OP": - return "is not" if inst.argval else "is" - elif inst.opname == "CONTAINS_OP": - return "not in" if inst.argval else "in" - elif inst.opname in OpNames.UNARY: - return OpNames.UNARY[inst.opname] - elif inst.opname == "BINARY_SUBSCR": - return "replace" - else: - msg = ( - "unrecognized opname" - "\n\nPlease report a bug to https://github.com/pola-rs/polars/issues" - " with the content of function you were passing to `map` and the" - f" following instruction object:\n{inst!r}" - ) - raise AssertionError(msg) - - def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str: - """Take stack entry value and convert to polars expression string.""" - if isinstance(value, StackValue): - op = value.operator - e1 = self._expr(value.left_operand, col, param_name, depth + 1) - if value.operator_arity == 1: - if op not in OpNames.UNARY_VALUES: - if e1.startswith("pl.col("): - call = "" if op.endswith(")") else "()" - return f"{e1}.{op}{call}" - if e1[0] in OpNames.UNARY_VALUES and e1[1:].startswith("pl.col("): - call = "" if op.endswith(")") else "()" - return f"({e1}).{op}{call}" - - # support use of consts as numpy/builtin params, eg: - # "np.sin(3) + np.cos(x)", or "len('const_string') + len(x)" - pfx = "np." if op in _NUMPY_FUNCTIONS else "" - return f"{pfx}{op}({e1})" - return f"{op}{e1}" - else: - e2 = self._expr(value.right_operand, col, param_name, depth + 1) - if op in ("is", "is not") and value[2] == "None": - not_ = "" if op == "is" else "not_" - return f"{e1}.is_{not_}null()" - elif op in ("in", "not in"): - not_ = "" if op == "in" else "~" - return ( - f"{not_}({e1}.is_in({e2}))" - if " " in e1 - else f"{not_}{e1}.is_in({e2})" - ) - elif op == "replace": - if not self._caller_variables: - self._caller_variables.update(_get_all_caller_variables()) - if not isinstance(self._caller_variables.get(e1, None), dict): - msg = "require dict mapping" - raise NotImplementedError(msg) - return f"{e2}.{op}({e1})" - elif op == "<<": - # Result of 2**e2 might be float is e2 was negative. - # But, if e1 << e2 was valid, then e2 must have been positive. - # Hence, the output of 2**e2 can be safely cast to Int64, which - # may be necessary if chaining operations which assume Int64 output. - return f"({e1}*2**{e2}).cast(pl.Int64)" - elif op == ">>": - # Motivation for the cast is the same as in the '<<' case above. - return f"({e1} / 2**{e2}).cast(pl.Int64)" - else: - expr = f"{e1} {op} {e2}" - return f"({expr})" if depth else expr - - elif value == param_name: - return f'pl.col("{col}")' - - return value - - def _to_intermediate_stack( - self, instructions: list[Instruction], map_target: MapTarget - ) -> StackEntry: - """Take postfix bytecode and convert to an intermediate natural-order stack.""" - if map_target in ("expr", "series"): - stack: list[StackEntry] = [] - for inst in instructions: - stack.append( - inst.argrepr - if inst.opname in OpNames.LOAD - else ( - StackValue( - operator=self.op(inst), - operator_arity=1, - left_operand=stack.pop(), # type: ignore[arg-type] - right_operand=None, # type: ignore[arg-type] - ) - if ( - inst.opname in OpNames.UNARY - or OpNames.SYNTHETIC.get(inst.opname) == 1 - ) - else StackValue( - operator=self.op(inst), - operator_arity=2, - left_operand=stack.pop(-2), # type: ignore[arg-type] - right_operand=stack.pop(-1), # type: ignore[arg-type] - ) - ) - ) - return stack[0] - - # TODO: dataframe.apply(...) - msg = f"TODO: {map_target!r} apply" - raise NotImplementedError(msg) - - -class RewrittenInstructions: - """ - Standalone class that applies Instruction rewrite/filtering rules. - - This significantly simplifies subsequent parsing by injecting - synthetic POLARS_EXPRESSION ops into the Instruction stream for - easy identification/translation and separates the parsing logic - from the identification of expression translation opportunities. +def _get_shared_lib_location(main_file: Any) -> str: """ + Get the location of the dynamic library file. - _ignored_ops = frozenset( - [ - "COPY", - "COPY_FREE_VARS", - "POP_TOP", - "PRECALL", - "PUSH_NULL", - "RESUME", - "RETURN_VALUE", - ] - ) - _caller_variables: ClassVar[dict[str, Any]] = {} - - def __init__(self, instructions: Iterator[Instruction]): - self._original_instructions = list(instructions) - self._rewritten_instructions = self._rewrite( - self._upgrade_instruction(inst) - for inst in self._original_instructions - if inst.opname not in self._ignored_ops - ) - - def __len__(self) -> int: - return len(self._rewritten_instructions) - - def __iter__(self) -> Iterator[Instruction]: - return iter(self._rewritten_instructions) - - def __getitem__(self, item: Any) -> Instruction: - return self._rewritten_instructions[item] - - def _matches( - self, - idx: int, - *, - opnames: list[AbstractSet[str]], - argvals: list[AbstractSet[Any] | dict[Any, Any]] | None, - ) -> list[Instruction]: - """ - Check if a sequence of Instructions matches the specified ops/argvals. - - Parameters - ---------- - idx - The index of the first instruction to check. - opnames - The full opname sequence that defines a match. - argvals - Associated argvals that must also match (in same position as opnames). - """ - n_required_ops, argvals = len(opnames), argvals or [] - instructions = self._instructions[idx : idx + n_required_ops] - if len(instructions) == n_required_ops and all( - inst.opname in match_opnames - and (match_argval is None or inst.argval in match_argval) - for inst, match_opnames, match_argval in zip_longest( - instructions, opnames, argvals - ) - ): - return instructions - return [] - - def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: - """ - Apply rewrite rules, potentially injecting synthetic operations. - - Rules operate on the instruction stream and can examine/modify - it as needed, pushing updates into "updated_instructions" and - returning True/False to indicate if any changes were made. - """ - self._instructions = list(instructions) - updated_instructions: list[Instruction] = [] - idx = 0 - while idx < len(self._instructions): - inst, increment = self._instructions[idx], 1 - if inst.opname not in OpNames.LOAD or not any( - (increment := map_rewrite(idx, updated_instructions)) - for map_rewrite in ( - # add any other rewrite methods here - self._rewrite_functions, - self._rewrite_methods, - self._rewrite_builtins, - ) - ): - updated_instructions.append(inst) - idx += increment or 1 - return updated_instructions - - def _rewrite_builtins( - self, idx: int, updated_instructions: list[Instruction] - ) -> int: - """Replace builtin function calls with a synthetic POLARS_EXPRESSION op.""" - if matching_instructions := self._matches( - idx, - opnames=[{"LOAD_GLOBAL"}, {"LOAD_FAST", "LOAD_CONST"}, OpNames.CALL], - argvals=[_PYTHON_BUILTINS], - ): - inst1, inst2 = matching_instructions[:2] - if (argval := inst1.argval) in _PYTHON_CASTS_MAP: - dtype = _PYTHON_CASTS_MAP[argval] - argval = f"cast(pl.{dtype})" - - synthetic_call = inst1._replace( - opname="POLARS_EXPRESSION", - argval=argval, - argrepr=argval, - offset=inst2.offset, - ) - # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order - operand = inst2._replace(offset=inst1.offset) - updated_instructions.extend((operand, synthetic_call)) - - return len(matching_instructions) + .. deprecated:: 0.20.16 + Use :func:`polars.plugins.register_plugin_function` instead. - def _rewrite_functions( - self, idx: int, updated_instructions: list[Instruction] - ) -> int: - """Replace function calls with a synthetic POLARS_EXPRESSION op.""" - for function_kind in _FUNCTION_KINDS: - opnames: list[AbstractSet[str]] = [ - {"LOAD_GLOBAL", "LOAD_DEREF"}, - *function_kind["module_opname"], - *function_kind["attribute_opname"], - *function_kind["argument_1_opname"], - *function_kind["argument_1_unary_opname"], - *function_kind["argument_2_opname"], - OpNames.CALL, - ] - if matching_instructions := self._matches( - idx, - opnames=opnames, - argvals=[ - *function_kind["module_name"], - *function_kind["attribute_name"], - *function_kind["function_name"], - ], - ): - attribute_count = len(function_kind["attribute_name"]) - inst1, inst2, inst3 = matching_instructions[ - attribute_count : 3 + attribute_count - ] - if inst1.argval == "json": - expr_name = "str.json_decode" - elif inst1.argval == "datetime": - fmt = matching_instructions[attribute_count + 3].argval - expr_name = f'str.to_datetime(format="{fmt}")' - if not self._is_stdlib_datetime( - inst1.argval, - matching_instructions[0].argval, - fmt, - attribute_count, - ): - return 0 - else: - expr_name = inst2.argval - synthetic_call = inst1._replace( - opname="POLARS_EXPRESSION", - argval=expr_name, - argrepr=expr_name, - offset=inst3.offset, - ) - # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order - operand = inst3._replace(offset=inst1.offset) - updated_instructions.extend( - ( - operand, - matching_instructions[3 + attribute_count], - synthetic_call, - ) - if function_kind["argument_1_unary_opname"] - else (operand, synthetic_call) - ) - return len(matching_instructions) + Warnings + -------- + This function is deprecated and will be removed in the next breaking release. + The new `polars.plugins.register_plugin_function` function has this + functionality built in. Use `from polars.plugins import register_plugin_function` + to import that function. - return 0 - - def _rewrite_methods( - self, idx: int, updated_instructions: list[Instruction] - ) -> int: - """Replace python method calls with synthetic POLARS_EXPRESSION op.""" - if matching_instructions := self._matches( - idx, - opnames=[ - OpNames.LOAD_ATTR if _MIN_PY312 else {"LOAD_METHOD"}, - OpNames.CALL, - ], - argvals=[_PYTHON_METHODS_MAP], - ): - inst = matching_instructions[0] - expr_name = _PYTHON_METHODS_MAP[inst.argval] - synthetic_call = inst._replace( - opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name - ) - updated_instructions.append(synthetic_call) - - return len(matching_instructions) - - @staticmethod - def _upgrade_instruction(inst: Instruction) -> Instruction: - """Rewrite any older binary opcodes using py 3.11 'BINARY_OP' instead.""" - if not _MIN_PY311 and inst.opname in OpNames.BINARY: - inst = inst._replace( - argrepr=OpNames.BINARY[inst.opname], - opname="BINARY_OP", - ) - return inst - - def _is_stdlib_datetime( - self, function_name: str, module_name: str, fmt: str, attribute_count: int - ) -> bool: - if not self._caller_variables: - self._caller_variables.update(_get_all_caller_variables()) - vars = self._caller_variables - return ( - attribute_count == 0 and vars.get(function_name) is datetime.datetime - ) or (attribute_count == 1 and vars.get(module_name) is datetime) - - -def _is_raw_function(function: Callable[[Any], Any]) -> tuple[str, str]: - """Identify translatable calls that aren't wrapped inside a lambda/function.""" - try: - func_module = function.__class__.__module__ - func_name = function.__name__ - except AttributeError: - return "", "" - - # numpy function calls - if func_module == "numpy" and func_name in _NUMPY_FUNCTIONS: - return "np", f"{func_name}()" - - # python function calls - elif func_module == "builtins": - if func_name in _PYTHON_CASTS_MAP: - return "builtins", f"cast(pl.{_PYTHON_CASTS_MAP[func_name]})" - elif func_name == "loads": - import json # double-check since it is referenced via 'builtins' - - if function is json.loads: - return "json", "str.json_decode()" - - return "", "" - - -def warn_on_inefficient_map( - function: Callable[[Any], Any], columns: list[str], map_target: MapTarget -) -> None: + Check the user guide for the recommended way to register a plugin: + https://docs.pola.rs/user-guide/expressions/plugins """ - Generate `PolarsInefficientMapWarning` on poor usage of a `map` function. - - Parameters - ---------- - function - The function passed to `map`. - columns - The column names of the original object; in the case of an `Expr` this - will be a list of length 1 containing the expression's root name. - map_target - The target of the `map` call. One of `"expr"`, `"frame"`, - or `"series"`. - """ - if map_target == "frame": - msg = "TODO: 'frame' map-function parsing" - raise NotImplementedError(msg) - - # note: we only consider simple functions with a single col/param - if not (col := columns and columns[0]): - return None - - # the parser introspects function bytecode to determine if we can - # rewrite as a much more optimal native polars expression instead - parser = BytecodeParser(function, map_target) - if parser.can_attempt_rewrite(): - parser.warn(col) - else: - # handle bare numpy/json functions - module, suggestion = _is_raw_function(function) - if module and suggestion: - fn = function.__name__ - parser.warn( - col, - suggestion_override=f'pl.col("{col}").{suggestion}', - udf_override=fn if module == "builtins" else f"{module}.{fn}", - ) - - -def is_shared_lib(file: str) -> bool: - return file.endswith((".so", ".dll", ".pyd")) - - -def _get_shared_lib_location(main_file: Any) -> str: - import os - directory = os.path.dirname(main_file) # noqa: PTH120 return os.path.join( # noqa: PTH118 - directory, next(filter(is_shared_lib, os.listdir(directory))) + directory, next(filter(_is_shared_lib, os.listdir(directory))) ) -__all__ = ["BytecodeParser", "warn_on_inefficient_map", "_get_shared_lib_location"] +def _is_shared_lib(file: str) -> bool: + return file.endswith((".so", ".dll", ".pyd")) diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 4911687cea2ea..b21c65f12e918 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -39,17 +39,18 @@ Changelog = "https://github.com/pola-rs/polars/releases" [project.optional-dependencies] # NOTE: keep this list in sync with show_versions() and requirements-dev.txt -adbc = ["adbc_driver_sqlite"] +adbc = ["adbc_driver_manager", "adbc_driver_sqlite"] cloudpickle = ["cloudpickle"] connectorx = ["connectorx >= 0.3.2"] deltalake = ["deltalake >= 0.14.0"] +fastexcel = ["fastexcel >= 0.9"] fsspec = ["fsspec"] gevent = ["gevent"] -plot = ["hvplot >= 0.9.1"] matplotlib = ["matplotlib"] numpy = ["numpy >= 1.16.0"] openpyxl = ["openpyxl >= 3.0.0"] pandas = ["pyarrow >= 7.0.0", "pandas"] +plot = ["hvplot >= 0.9.1"] pyarrow = ["pyarrow >= 7.0.0"] pydantic = ["pydantic"] pyiceberg = ["pyiceberg >= 0.5.0"] @@ -59,7 +60,7 @@ timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_syst xlsx2csv = ["xlsx2csv >= 0.8.0"] xlsxwriter = ["xlsxwriter"] all = [ - "polars[pyarrow,pandas,numpy,fsspec,plot,connectorx,xlsx2csv,deltalake,timezone,pydantic,pyiceberg,sqlalchemy,xlsxwriter,adbc,cloudpickle,gevent]", + "polars[adbc,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,numpy,pandas,plot,pyarrow,pydantic,pyiceberg,sqlalchemy,timezone,xlsx2csv,xlsxwriter]", ] [tool.maturin] @@ -90,6 +91,7 @@ module = [ "fsspec.*", "gevent", "hvplot.*", + "kuzu", "matplotlib.*", "moto.server", "openpyxl", @@ -179,7 +181,7 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"] +"tests/**/*.py" = ["D100", "D102", "D103", "B018", "FBT001"] [tool.ruff.lint.pycodestyle] max-doc-length = 88 diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index aeb9d3be53bd6..abf2550ec9c19 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -2,14 +2,13 @@ # We're not pinning package dependencies, because our tests need to pass with the # latest version of the packages. ---prefer-binary - # ----- # BUILD # ----- maturin -patchelf; platform_system == 'Linux' # Extra dependency for maturin, only for Linux +# extra dependency for maturin (linux-only) +patchelf; platform_system == 'Linux' # ------------ # DEPENDENCIES @@ -30,6 +29,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows' # TODO: Remove version constraint for connectorx when Python 3.12 is supported: # https://github.com/sfu-db/connector-x/issues/527 connectorx; python_version <= '3.11' +kuzu # Cloud cloudpickle fsspec @@ -37,7 +37,7 @@ s3fs[boto3] # Spreadsheet ezodf lxml -fastexcel>=0.8.0 +fastexcel>=0.9 openpyxl pyxlsb xlsx2csv diff --git a/py-polars/requirements-lint.txt b/py-polars/requirements-lint.txt index 225616bb2c75c..ef37452db82c4 100644 --- a/py-polars/requirements-lint.txt +++ b/py-polars/requirements-lint.txt @@ -1,3 +1,3 @@ mypy==1.8.0 -ruff==0.2.0 -typos==1.17.2 +ruff==0.3.0 +typos==1.18.2 diff --git a/py-polars/src/arrow_interop/to_rust.rs b/py-polars/src/arrow_interop/to_rust.rs index 612cc64450982..ecdd13e1a364c 100644 --- a/py-polars/src/arrow_interop/to_rust.rs +++ b/py-polars/src/arrow_interop/to_rust.rs @@ -98,7 +98,7 @@ pub fn to_rust_df(rb: &[&PyAny]) -> PyResult { }?; // no need to check as a record batch has the same guarantees - Ok(DataFrame::new_no_checks(columns)) + Ok(unsafe { DataFrame::new_no_checks(columns) }) }) .collect::>>()?; diff --git a/py-polars/src/conversion/any_value.rs b/py-polars/src/conversion/any_value.rs index a66ec63d5354f..2c524365bcb02 100644 --- a/py-polars/src/conversion/any_value.rs +++ b/py-polars/src/conversion/any_value.rs @@ -14,84 +14,7 @@ use crate::series::PySeries; impl IntoPy for Wrap> { fn into_py(self, py: Python) -> PyObject { - let utils = UTILS.as_ref(py); - match self.0 { - AnyValue::UInt8(v) => v.into_py(py), - AnyValue::UInt16(v) => v.into_py(py), - AnyValue::UInt32(v) => v.into_py(py), - AnyValue::UInt64(v) => v.into_py(py), - AnyValue::Int8(v) => v.into_py(py), - AnyValue::Int16(v) => v.into_py(py), - AnyValue::Int32(v) => v.into_py(py), - AnyValue::Int64(v) => v.into_py(py), - AnyValue::Float32(v) => v.into_py(py), - AnyValue::Float64(v) => v.into_py(py), - AnyValue::Null => py.None(), - AnyValue::Boolean(v) => v.into_py(py), - AnyValue::String(v) => v.into_py(py), - AnyValue::StringOwned(v) => v.into_py(py), - AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { - let s = if arr.is_null() { - rev.get(idx) - } else { - unsafe { arr.deref_unchecked().value(idx as usize) } - }; - s.into_py(py) - }, - AnyValue::Date(v) => { - let convert = utils.getattr(intern!(py, "_to_python_date")).unwrap(); - convert.call1((v,)).unwrap().into_py(py) - }, - AnyValue::Datetime(v, time_unit, time_zone) => { - let convert = utils.getattr(intern!(py, "_to_python_datetime")).unwrap(); - let time_unit = time_unit.to_ascii(); - convert - .call1((v, time_unit, time_zone.as_ref().map(|s| s.as_str()))) - .unwrap() - .into_py(py) - }, - AnyValue::Duration(v, time_unit) => { - let convert = utils.getattr(intern!(py, "_to_python_timedelta")).unwrap(); - let time_unit = time_unit.to_ascii(); - convert.call1((v, time_unit)).unwrap().into_py(py) - }, - AnyValue::Time(v) => { - let convert = utils.getattr(intern!(py, "_to_python_time")).unwrap(); - convert.call1((v,)).unwrap().into_py(py) - }, - AnyValue::Array(v, _) | AnyValue::List(v) => PySeries::new(v).to_list(), - ref av @ AnyValue::Struct(_, _, flds) => struct_dict(py, av._iter_struct_av(), flds), - AnyValue::StructOwned(payload) => struct_dict(py, payload.0.into_iter(), &payload.1), - #[cfg(feature = "object")] - AnyValue::Object(v) => { - let object = v.as_any().downcast_ref::().unwrap(); - object.inner.clone() - }, - #[cfg(feature = "object")] - AnyValue::ObjectOwned(v) => { - let object = v.0.as_any().downcast_ref::().unwrap(); - object.inner.clone() - }, - AnyValue::Binary(v) => v.into_py(py), - AnyValue::BinaryOwned(v) => v.into_py(py), - AnyValue::Decimal(v, scale) => { - let convert = utils.getattr(intern!(py, "_to_python_decimal")).unwrap(); - const N: usize = 3; - let mut buf = [0_u128; N]; - let n_digits = decimal_to_digits(v.abs(), &mut buf); - let buf = unsafe { - std::slice::from_raw_parts( - buf.as_slice().as_ptr() as *const u8, - N * std::mem::size_of::(), - ) - }; - let digits = PyTuple::new(py, buf.iter().take(n_digits)); - convert - .call1((v.is_negative() as u8, digits, n_digits, -(scale as i32))) - .unwrap() - .into_py(py) - }, - } + any_value_into_py_object(self.0, py) } } @@ -101,316 +24,388 @@ impl ToPyObject for Wrap> { } } +impl<'s> FromPyObject<'s> for Wrap> { + fn extract(ob: &'s PyAny) -> PyResult { + py_object_to_any_value(ob, true).map(Wrap) + } +} + +pub(crate) fn any_value_into_py_object(av: AnyValue, py: Python) -> PyObject { + let utils = UTILS.as_ref(py); + match av { + AnyValue::UInt8(v) => v.into_py(py), + AnyValue::UInt16(v) => v.into_py(py), + AnyValue::UInt32(v) => v.into_py(py), + AnyValue::UInt64(v) => v.into_py(py), + AnyValue::Int8(v) => v.into_py(py), + AnyValue::Int16(v) => v.into_py(py), + AnyValue::Int32(v) => v.into_py(py), + AnyValue::Int64(v) => v.into_py(py), + AnyValue::Float32(v) => v.into_py(py), + AnyValue::Float64(v) => v.into_py(py), + AnyValue::Null => py.None(), + AnyValue::Boolean(v) => v.into_py(py), + AnyValue::String(v) => v.into_py(py), + AnyValue::StringOwned(v) => v.into_py(py), + AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { + let s = if arr.is_null() { + rev.get(idx) + } else { + unsafe { arr.deref_unchecked().value(idx as usize) } + }; + s.into_py(py) + }, + AnyValue::Date(v) => { + let convert = utils.getattr(intern!(py, "to_py_date")).unwrap(); + convert.call1((v,)).unwrap().into_py(py) + }, + AnyValue::Datetime(v, time_unit, time_zone) => { + let convert = utils.getattr(intern!(py, "to_py_datetime")).unwrap(); + let time_unit = time_unit.to_ascii(); + convert + .call1((v, time_unit, time_zone.as_ref().map(|s| s.as_str()))) + .unwrap() + .into_py(py) + }, + AnyValue::Duration(v, time_unit) => { + let convert = utils.getattr(intern!(py, "to_py_timedelta")).unwrap(); + let time_unit = time_unit.to_ascii(); + convert.call1((v, time_unit)).unwrap().into_py(py) + }, + AnyValue::Time(v) => { + let convert = utils.getattr(intern!(py, "to_py_time")).unwrap(); + convert.call1((v,)).unwrap().into_py(py) + }, + AnyValue::Array(v, _) | AnyValue::List(v) => PySeries::new(v).to_list(), + ref av @ AnyValue::Struct(_, _, flds) => struct_dict(py, av._iter_struct_av(), flds), + AnyValue::StructOwned(payload) => struct_dict(py, payload.0.into_iter(), &payload.1), + #[cfg(feature = "object")] + AnyValue::Object(v) => { + let object = v.as_any().downcast_ref::().unwrap(); + object.inner.clone() + }, + #[cfg(feature = "object")] + AnyValue::ObjectOwned(v) => { + let object = v.0.as_any().downcast_ref::().unwrap(); + object.inner.clone() + }, + AnyValue::Binary(v) => v.into_py(py), + AnyValue::BinaryOwned(v) => v.into_py(py), + AnyValue::Decimal(v, scale) => { + let convert = utils.getattr(intern!(py, "to_py_decimal")).unwrap(); + const N: usize = 3; + let mut buf = [0_u128; N]; + let n_digits = decimal_to_digits(v.abs(), &mut buf); + let buf = unsafe { + std::slice::from_raw_parts( + buf.as_slice().as_ptr() as *const u8, + N * std::mem::size_of::(), + ) + }; + let digits = PyTuple::new(py, buf.iter().take(n_digits)); + convert + .call1((v.is_negative() as u8, digits, n_digits, -(scale as i32))) + .unwrap() + .into_py(py) + }, + } +} + type TypeObjectPtr = usize; -type InitFn = fn(&PyAny) -> PyResult>>; +type InitFn = fn(&PyAny, bool) -> PyResult; pub(crate) static LUT: crate::gil_once_cell::GILOnceCell> = crate::gil_once_cell::GILOnceCell::new(); -impl<'s> FromPyObject<'s> for Wrap> { - fn extract(ob: &'s PyAny) -> PyResult { - // conversion functions - fn get_bool(ob: &PyAny) -> PyResult>> { - Ok(AnyValue::Boolean(ob.extract::().unwrap()).into()) - } +pub(crate) fn py_object_to_any_value(ob: &PyAny, strict: bool) -> PyResult { + // conversion functions + fn get_bool(ob: &PyAny, _strict: bool) -> PyResult { + let b = ob.extract::().unwrap(); + Ok(AnyValue::Boolean(b)) + } - fn get_int(ob: &PyAny) -> PyResult>> { - // can overflow - match ob.extract::() { - Ok(v) => Ok(AnyValue::Int64(v).into()), - Err(_) => Ok(AnyValue::UInt64(ob.extract::()?).into()), - } + fn get_int(ob: &PyAny, _strict: bool) -> PyResult { + // can overflow + match ob.extract::() { + Ok(v) => Ok(AnyValue::Int64(v)), + Err(_) => Ok(AnyValue::UInt64(ob.extract::()?)), } + } - fn get_float(ob: &PyAny) -> PyResult>> { - Ok(AnyValue::Float64(ob.extract::().unwrap()).into()) - } + fn get_float(ob: &PyAny, _strict: bool) -> PyResult { + Ok(AnyValue::Float64(ob.extract::().unwrap())) + } - fn get_str(ob: &PyAny) -> PyResult>> { - let value = ob.extract::<&str>().unwrap(); - Ok(AnyValue::String(value).into()) - } + fn get_str(ob: &PyAny, _strict: bool) -> PyResult { + let value = ob.extract::<&str>().unwrap(); + Ok(AnyValue::String(value)) + } - fn get_struct(ob: &PyAny) -> PyResult>> { - let dict = ob.downcast::().unwrap(); - let len = dict.len(); - let mut keys = Vec::with_capacity(len); - let mut vals = Vec::with_capacity(len); - for (k, v) in dict.into_iter() { - let key = k.extract::<&str>()?; - let val = v.extract::>()?.0; - let dtype = DataType::from(&val); - keys.push(Field::new(key, dtype)); - vals.push(val) - } - Ok(Wrap(AnyValue::StructOwned(Box::new((vals, keys))))) + fn get_struct(ob: &PyAny, strict: bool) -> PyResult> { + let dict = ob.downcast::().unwrap(); + let len = dict.len(); + let mut keys = Vec::with_capacity(len); + let mut vals = Vec::with_capacity(len); + for (k, v) in dict.into_iter() { + let key = k.extract::<&str>()?; + let val = py_object_to_any_value(v, strict)?; + let dtype = val.dtype(); + keys.push(Field::new(key, dtype)); + vals.push(val) } + Ok(AnyValue::StructOwned(Box::new((vals, keys)))) + } - fn get_list(ob: &PyAny) -> PyResult> { - fn get_list_with_constructor(ob: &PyAny) -> PyResult> { - // Use the dedicated constructor - // this constructor is able to go via dedicated type constructors - // so it can be much faster - Python::with_gil(|py| { - let s = SERIES.call1(py, (ob,))?; - get_series_el(s.as_ref(py)) - }) - } + fn get_list(ob: &PyAny, strict: bool) -> PyResult { + fn get_list_with_constructor(ob: &PyAny) -> PyResult { + // Use the dedicated constructor + // this constructor is able to go via dedicated type constructors + // so it can be much faster + Python::with_gil(|py| { + let s = SERIES.call1(py, (ob,))?; + get_series_el(s.as_ref(py), true) + }) + } - if ob.is_empty()? { - Ok(Wrap(AnyValue::List(Series::new_empty("", &DataType::Null)))) - } else if ob.is_instance_of::() | ob.is_instance_of::() { - let list = ob.downcast::().unwrap(); + if ob.is_empty()? { + Ok(AnyValue::List(Series::new_empty("", &DataType::Null))) + } else if ob.is_instance_of::() | ob.is_instance_of::() { + const INFER_SCHEMA_LENGTH: usize = 25; - let mut avs = Vec::with_capacity(25); - let mut iter = list.iter()?; + let list = ob.downcast::().unwrap(); - for item in (&mut iter).take(25) { - avs.push(item?.extract::>()?.0) - } + let mut avs = Vec::with_capacity(INFER_SCHEMA_LENGTH); + let mut iter = list.iter()?; - let (dtype, n_types) = any_values_to_dtype(&avs).map_err(PyPolarsErr::from)?; + for item in (&mut iter).take(INFER_SCHEMA_LENGTH) { + let av = py_object_to_any_value(item?, strict)?; + avs.push(av) + } - // we only take this path if there is no question of the data-type - if dtype.is_primitive() && n_types == 1 { - get_list_with_constructor(ob) - } else { - // push the rest - avs.reserve(list.len()?); - for item in iter { - avs.push(item?.extract::>()?.0) - } + let (dtype, n_types) = any_values_to_dtype(&avs).map_err(PyPolarsErr::from)?; - let s = Series::from_any_values_and_dtype("", &avs, &dtype, true) - .map_err(PyPolarsErr::from)?; - Ok(Wrap(AnyValue::List(s))) - } - } else { - // range will take this branch + // we only take this path if there is no question of the data-type + if dtype.is_primitive() && n_types == 1 { get_list_with_constructor(ob) + } else { + // push the rest + avs.reserve(list.len()?); + for item in iter { + let av = py_object_to_any_value(item?, strict)?; + avs.push(av) + } + + let s = Series::from_any_values_and_dtype("", &avs, &dtype, strict) + .map_err(PyPolarsErr::from)?; + Ok(AnyValue::List(s)) } + } else { + // range will take this branch + get_list_with_constructor(ob) } + } - fn get_series_el(ob: &PyAny) -> PyResult>> { - let py_pyseries = ob.getattr(intern!(ob.py(), "_s")).unwrap(); - let series = py_pyseries.extract::().unwrap().series; - Ok(Wrap(AnyValue::List(series))) - } + fn get_series_el(ob: &PyAny, _strict: bool) -> PyResult> { + let py_pyseries = ob.getattr(intern!(ob.py(), "_s")).unwrap(); + let series = py_pyseries.extract::().unwrap().series; + Ok(AnyValue::List(series)) + } - fn get_bin(ob: &PyAny) -> PyResult> { - let value = ob.extract::<&[u8]>().unwrap(); - Ok(AnyValue::Binary(value).into()) - } + fn get_bin(ob: &PyAny, _strict: bool) -> PyResult { + let value = ob.extract::<&[u8]>().unwrap(); + Ok(AnyValue::Binary(value)) + } - fn get_null(_ob: &PyAny) -> PyResult> { - Ok(AnyValue::Null.into()) - } + fn get_null(_ob: &PyAny, _strict: bool) -> PyResult { + Ok(AnyValue::Null) + } - fn get_timedelta(ob: &PyAny) -> PyResult> { - Python::with_gil(|py| { - let td = UTILS - .as_ref(py) - .getattr(intern!(py, "_timedelta_to_pl_timedelta")) - .unwrap() - .call1((ob, intern!(py, "us"))) - .unwrap(); - let v = td.extract::().unwrap(); - Ok(Wrap(AnyValue::Duration(v, TimeUnit::Microseconds))) - }) - } + fn get_date(ob: &PyAny, _strict: bool) -> PyResult { + Python::with_gil(|py| { + let date = UTILS + .as_ref(py) + .getattr(intern!(py, "date_to_int")) + .unwrap() + .call1((ob,)) + .unwrap(); + let v = date.extract::().unwrap(); + Ok(AnyValue::Date(v)) + }) + } - fn get_time(ob: &PyAny) -> PyResult> { - Python::with_gil(|py| { - let time = UTILS - .as_ref(py) - .getattr(intern!(py, "_time_to_pl_time")) - .unwrap() - .call1((ob,)) - .unwrap(); - let v = time.extract::().unwrap(); - Ok(Wrap(AnyValue::Time(v))) - }) - } + fn get_datetime(ob: &PyAny, _strict: bool) -> PyResult { + Python::with_gil(|py| { + let date = UTILS + .as_ref(py) + .getattr(intern!(py, "datetime_to_int")) + .unwrap() + .call1((ob, intern!(py, "us"))) + .unwrap(); + let v = date.extract::().unwrap(); + Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None)) + }) + } - fn get_decimal(ob: &PyAny) -> PyResult> { - let (sign, digits, exp): (i8, Vec, i32) = ob - .call_method0(intern!(ob.py(), "as_tuple")) + fn get_timedelta(ob: &PyAny, _strict: bool) -> PyResult { + Python::with_gil(|py| { + let td = UTILS + .as_ref(py) + .getattr(intern!(py, "timedelta_to_int")) .unwrap() - .extract() + .call1((ob, intern!(py, "us"))) .unwrap(); - // note: using Vec is not the most efficient thing here (input is a tuple) - let (mut v, scale) = abs_decimal_from_digits(digits, exp).ok_or_else(|| { - PyErr::from(PyPolarsErr::Other( - "Decimal is too large to fit in Decimal128".into(), - )) - })?; - if sign > 0 { - v = -v; // won't overflow since -i128::MAX > i128::MIN + let v = td.extract::().unwrap(); + Ok(AnyValue::Duration(v, TimeUnit::Microseconds)) + }) + } + + fn get_time(ob: &PyAny, _strict: bool) -> PyResult { + Python::with_gil(|py| { + let time = UTILS + .as_ref(py) + .getattr(intern!(py, "time_to_int")) + .unwrap() + .call1((ob,)) + .unwrap(); + let v = time.extract::().unwrap(); + Ok(AnyValue::Time(v)) + }) + } + + fn get_decimal(ob: &PyAny, _strict: bool) -> PyResult { + fn abs_decimal_from_digits( + digits: impl IntoIterator, + exp: i32, + ) -> Option<(i128, usize)> { + const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1; + let mut v = 0_i128; + for (i, d) in digits.into_iter().map(i128::from).enumerate() { + if i < 38 { + v = v * 10 + d; + } else { + v = v.checked_mul(10).and_then(|v| v.checked_add(d))?; + } } - Ok(Wrap(AnyValue::Decimal(v, scale))) + // we only support non-negative scale (=> non-positive exponent) + let scale = if exp > 0 { + // the decimal may be in a non-canonical representation, try to fix it first + v = 10_i128 + .checked_pow(exp as u32) + .and_then(|factor| v.checked_mul(factor))?; + 0 + } else { + (-exp) as usize + }; + // TODO: do we care for checking if it fits in MAX_ABS_DEC? (if we set precision to None anyway?) + (v <= MAX_ABS_DEC).then_some((v, scale)) } - fn get_object(ob: &PyAny) -> PyResult> { - #[cfg(feature = "object")] - { - // this is slow, but hey don't use objects - let v = &ObjectValue { inner: ob.into() }; - Ok(Wrap(AnyValue::ObjectOwned(OwnedObject(v.to_boxed())))) - } - #[cfg(not(feature = "object"))] - { - panic!("activate object") - } + let (sign, digits, exp): (i8, Vec, i32) = ob + .call_method0(intern!(ob.py(), "as_tuple")) + .unwrap() + .extract() + .unwrap(); + // note: using Vec is not the most efficient thing here (input is a tuple) + let (mut v, scale) = abs_decimal_from_digits(digits, exp).ok_or_else(|| { + PyErr::from(PyPolarsErr::Other( + "Decimal is too large to fit in Decimal128".into(), + )) + })?; + if sign > 0 { + v = -v; // won't overflow since -i128::MAX > i128::MIN } + Ok(AnyValue::Decimal(v, scale)) + } - // TYPE key - let type_object_ptr = PyType::as_type_ptr(ob.get_type()) as usize; + fn get_object(ob: &PyAny, _strict: bool) -> PyResult { + #[cfg(feature = "object")] + { + // this is slow, but hey don't use objects + let v = &ObjectValue { inner: ob.into() }; + Ok(AnyValue::ObjectOwned(OwnedObject(v.to_boxed()))) + } + #[cfg(not(feature = "object"))] + { + panic!("activate object") + } + } - Python::with_gil(|py| { - LUT.with_gil(py, |lut| { - // get the conversion function - let convert_fn = lut.entry(type_object_ptr).or_insert_with( - // This only runs if type is not in LUT - || { - if ob.is_instance_of::() { - get_bool - // TODO: this heap allocs on failure - } else if ob.extract::().is_ok() || ob.extract::().is_ok() { - get_int - } else if ob.is_instance_of::() { - get_float - } else if ob.is_instance_of::() { - get_str - } else if ob.is_instance_of::() { - get_struct - } else if ob.is_instance_of::() || ob.is_instance_of::() { - get_list - } else if ob.hasattr(intern!(py, "_s")).unwrap() { - get_series_el - } - // TODO: this heap allocs on failure - else if ob.extract::<&'s [u8]>().is_ok() { - get_bin - } else if ob.is_none() { - get_null - } else { - let type_name = ob.get_type().name().unwrap(); - match type_name { - "datetime" => convert_datetime, - "date" => convert_date, - "timedelta" => get_timedelta, - "time" => get_time, - "Decimal" => get_decimal, - "range" => get_list, - _ => { - // special branch for np.float as this fails isinstance float - if ob.extract::().is_ok() { - return get_float; - } + // TYPE key + let type_object_ptr = PyType::as_type_ptr(ob.get_type()) as usize; - // Can't use pyo3::types::PyDateTime with abi3-py37 feature, - // so need this workaround instead of `isinstance(ob, datetime)`. - let bases = ob - .get_type() - .getattr(intern!(py, "__bases__")) - .unwrap() - .iter() - .unwrap(); - for base in bases { - let parent_type = - base.unwrap().str().unwrap().to_str().unwrap(); - match parent_type { - "" => { - // `datetime.datetime` is a subclass of `datetime.date`, - // so need to check `datetime.datetime` first - return convert_datetime; - }, - "" => { - return convert_date; - }, - _ => (), - } + Python::with_gil(|py| { + LUT.with_gil(py, |lut| { + // get the conversion function + let convert_fn = lut.entry(type_object_ptr).or_insert_with( + // This only runs if type is not in LUT + || { + if ob.is_instance_of::() { + get_bool + // TODO: this heap allocs on failure + } else if ob.extract::().is_ok() || ob.extract::().is_ok() { + get_int + } else if ob.is_instance_of::() { + get_float + } else if ob.is_instance_of::() { + get_str + } else if ob.is_instance_of::() { + get_struct + } else if ob.is_instance_of::() || ob.is_instance_of::() { + get_list + } else if ob.hasattr(intern!(py, "_s")).unwrap() { + get_series_el + } + // TODO: this heap allocs on failure + else if ob.extract::<&[u8]>().is_ok() { + get_bin + } else if ob.is_none() { + get_null + } else { + let type_name = ob.get_type().name().unwrap(); + match type_name { + "datetime" => get_datetime, + "date" => get_date, + "timedelta" => get_timedelta, + "time" => get_time, + "Decimal" => get_decimal, + "range" => get_list, + _ => { + // special branch for np.float as this fails isinstance float + if ob.extract::().is_ok() { + return get_float; + } + + // Can't use pyo3::types::PyDateTime with abi3-py37 feature, + // so need this workaround instead of `isinstance(ob, datetime)`. + let bases = ob + .get_type() + .getattr(intern!(py, "__bases__")) + .unwrap() + .iter() + .unwrap(); + for base in bases { + let parent_type = + base.unwrap().str().unwrap().to_str().unwrap(); + match parent_type { + "" => { + // `datetime.datetime` is a subclass of `datetime.date`, + // so need to check `datetime.datetime` first + return get_datetime; + }, + "" => { + return get_date; + }, + _ => (), } + } - get_object - }, - } + get_object + }, } - }, - ); + } + }, + ); - convert_fn(ob) - }) + convert_fn(ob, strict) }) - } -} - -fn convert_date(ob: &PyAny) -> PyResult> { - Python::with_gil(|py| { - let date = UTILS - .as_ref(py) - .getattr(intern!(py, "_date_to_pl_date")) - .unwrap() - .call1((ob,)) - .unwrap(); - let v = date.extract::().unwrap(); - Ok(Wrap(AnyValue::Date(v))) }) } -fn convert_datetime(ob: &PyAny) -> PyResult> { - Python::with_gil(|py| { - // windows - #[cfg(target_arch = "windows")] - let (seconds, microseconds) = { - let convert = UTILS - .getattr(py, intern!(py, "_datetime_for_any_value_windows")) - .unwrap(); - let out = convert.call1(py, (ob,)).unwrap(); - let out: (i64, i64) = out.extract(py).unwrap(); - out - }; - // unix - #[cfg(not(target_arch = "windows"))] - let (seconds, microseconds) = { - let convert = UTILS - .getattr(py, intern!(py, "_datetime_for_any_value")) - .unwrap(); - let out = convert.call1(py, (ob,)).unwrap(); - let out: (i64, i64) = out.extract(py).unwrap(); - out - }; - - // s to us - let mut v = seconds * 1_000_000; - v += microseconds; - - // choose "us" as that is python's default unit - Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None).into()) - }) -} - -fn abs_decimal_from_digits( - digits: impl IntoIterator, - exp: i32, -) -> Option<(i128, usize)> { - const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1; - let mut v = 0_i128; - for (i, d) in digits.into_iter().map(i128::from).enumerate() { - if i < 38 { - v = v * 10 + d; - } else { - v = v.checked_mul(10).and_then(|v| v.checked_add(d))?; - } - } - // we only support non-negative scale (=> non-positive exponent) - let scale = if exp > 0 { - // the decimal may be in a non-canonical representation, try to fix it first - v = 10_i128 - .checked_pow(exp as u32) - .and_then(|factor| v.checked_mul(factor))?; - 0 - } else { - (-exp) as usize - }; - // TODO: do we care for checking if it fits in MAX_ABS_DEC? (if we set precision to None anyway?) - (v <= MAX_ABS_DEC).then_some((v, scale)) -} diff --git a/py-polars/src/conversion/chunked_array.rs b/py-polars/src/conversion/chunked_array.rs index 4831890931255..c2e65a0613dbe 100644 --- a/py-polars/src/conversion/chunked_array.rs +++ b/py-polars/src/conversion/chunked_array.rs @@ -1,9 +1,6 @@ -use polars::prelude::AnyValue; -#[cfg(feature = "cloud")] -use pyo3::conversion::{FromPyObject, IntoPy}; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyList, PyTuple}; -use pyo3::{intern, PyAny, PyResult}; use super::{decimal_to_digits, struct_dict}; use crate::prelude::*; @@ -115,12 +112,12 @@ impl ToPyObject for Wrap<&StructChunked> { impl ToPyObject for Wrap<&DurationChunked> { fn to_object(&self, py: Python) -> PyObject { let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_timedelta")).unwrap(); - let time_unit = Wrap(self.0.time_unit()).to_object(py); + let convert = utils.getattr(intern!(py, "to_py_timedelta")).unwrap(); + let time_unit = self.0.time_unit().to_ascii(); let iter = self .0 .into_iter() - .map(|opt_v| opt_v.map(|v| convert.call1((v, &time_unit)).unwrap())); + .map(|opt_v| opt_v.map(|v| convert.call1((v, time_unit)).unwrap())); PyList::new(py, iter).into_py(py) } } @@ -128,13 +125,13 @@ impl ToPyObject for Wrap<&DurationChunked> { impl ToPyObject for Wrap<&DatetimeChunked> { fn to_object(&self, py: Python) -> PyObject { let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_datetime")).unwrap(); - let time_unit = Wrap(self.0.time_unit()).to_object(py); + let convert = utils.getattr(intern!(py, "to_py_datetime")).unwrap(); + let time_unit = self.0.time_unit().to_ascii(); let time_zone = self.0.time_zone().to_object(py); let iter = self .0 .into_iter() - .map(|opt_v| opt_v.map(|v| convert.call1((v, &time_unit, &time_zone)).unwrap())); + .map(|opt_v| opt_v.map(|v| convert.call1((v, time_unit, &time_zone)).unwrap())); PyList::new(py, iter).into_py(py) } } @@ -151,7 +148,7 @@ pub(crate) fn time_to_pyobject_iter<'a>( ca: &'a TimeChunked, ) -> impl ExactSizeIterator> { let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_time")).unwrap(); + let convert = utils.getattr(intern!(py, "to_py_time")).unwrap(); ca.0.into_iter() .map(|opt_v| opt_v.map(|v| convert.call1((v,)).unwrap())) } @@ -159,7 +156,7 @@ pub(crate) fn time_to_pyobject_iter<'a>( impl ToPyObject for Wrap<&DateChunked> { fn to_object(&self, py: Python) -> PyObject { let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_date")).unwrap(); + let convert = utils.getattr(intern!(py, "to_py_date")).unwrap(); let iter = self .0 .into_iter() @@ -180,7 +177,7 @@ pub(crate) fn decimal_to_pyobject_iter<'a>( ca: &'a DecimalChunked, ) -> impl ExactSizeIterator> { let utils = UTILS.as_ref(py); - let convert = utils.getattr(intern!(py, "_to_python_decimal")).unwrap(); + let convert = utils.getattr(intern!(py, "to_py_decimal")).unwrap(); let py_scale = (-(ca.scale() as i32)).to_object(py); // if we don't know precision, the only safe bet is to set it to 39 let py_precision = ca.precision().unwrap_or(39).to_object(py); diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs index 59da00e0078de..4d08a4635e711 100644 --- a/py-polars/src/conversion/mod.rs +++ b/py-polars/src/conversion/mod.rs @@ -10,23 +10,18 @@ use polars::frame::row::Row; use polars::frame::NullStrategy; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; -#[cfg(feature = "ipc")] -use polars::io::ipc::IpcCompression; -use polars::prelude::AnyValue; use polars::series::ops::NullBehavior; -use polars_core::prelude::{IndexOrder, QuantileInterpolOptions}; use polars_core::utils::arrow::array::Array; use polars_core::utils::arrow::types::NativeType; use polars_lazy::prelude::*; #[cfg(feature = "cloud")] use polars_rs::io::cloud::CloudOptions; -use polars_utils::total_ord::TotalEq; +use polars_utils::total_ord::{TotalEq, TotalHash}; use pyo3::basic::CompareOp; -use pyo3::conversion::{FromPyObject, IntoPy}; use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PySequence}; -use pyo3::{intern, PyAny, PyResult}; use smartstring::alias::String as SmartString; use crate::error::PyPolarsErr; @@ -440,6 +435,7 @@ impl ToPyObject for Wrap { impl<'s> FromPyObject<'s> for Wrap> { fn extract(ob: &'s PyAny) -> PyResult { let vals = ob.extract::>>>()?; + // SAFETY. Wrap is repr transparent. let vals: Vec = unsafe { std::mem::transmute(vals) }; Ok(Wrap(Row(vals))) } @@ -498,6 +494,15 @@ impl TotalEq for ObjectValue { } } +impl TotalHash for ObjectValue { + fn tot_hash(&self, state: &mut H) + where + H: Hasher, + { + self.hash(state); + } +} + impl Display for ObjectValue { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.inner) diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 12e28ff37aecd..b357954eb98fb 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -4,16 +4,12 @@ use std::ops::Deref; use either::Either; use polars::frame::row::{rows_to_schema_supertypes, Row}; -use polars::frame::NullStrategy; #[cfg(feature = "avro")] use polars::io::avro::AvroCompression; -#[cfg(feature = "ipc")] -use polars::io::ipc::IpcCompression; use polars::io::mmap::ReaderBytes; use polars::io::RowIndex; use polars::prelude::*; use polars_core::export::arrow::datatypes::IntegerType; -use polars_core::frame::explode::MeltArgs; use polars_core::frame::*; use polars_core::utils::arrow::compute::cast::CastOptions; #[cfg(feature = "pivot")] @@ -1222,11 +1218,12 @@ impl PyDataFrame { } #[cfg(feature = "pivot")] + #[pyo3(signature = (index, columns, values, maintain_order, sort_columns, aggregate_expr, separator))] pub fn pivot_expr( &self, - values: Vec, index: Vec, columns: Vec, + values: Option>, maintain_order: bool, sort_columns: bool, aggregate_expr: Option, @@ -1236,9 +1233,9 @@ impl PyDataFrame { let agg_expr = aggregate_expr.map(|expr| expr.inner); let df = fun( &self.df, - values, index, columns, + values, sort_columns, agg_expr, separator, diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index fc82f4ef9b024..269236b7b2d4e 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -3,7 +3,6 @@ use std::ops::Neg; use polars::lazy::dsl; use polars::prelude::*; use polars::series::ops::NullBehavior; -use polars_core::prelude::QuantileInterpolOptions; use polars_core::series::IsSorted; use pyo3::class::basic::CompareOp; use pyo3::prelude::*; @@ -864,54 +863,6 @@ impl PyExpr { .into() } - #[cfg(feature = "ffi_plugin")] - fn register_plugin( - &self, - lib: &str, - symbol: &str, - args: Vec, - kwargs: Vec, - is_elementwise: bool, - input_wildcard_expansion: bool, - returns_scalar: bool, - cast_to_supertypes: bool, - pass_name_to_apply: bool, - changes_length: bool, - ) -> PyResult { - use polars_plan::prelude::*; - let inner = self.inner.clone(); - - let collect_groups = if is_elementwise { - ApplyOptions::ElementWise - } else { - ApplyOptions::GroupWise - }; - let mut input = Vec::with_capacity(args.len() + 1); - input.push(inner); - for a in args { - input.push(a.inner) - } - - Ok(Expr::Function { - input, - function: FunctionExpr::FfiPlugin { - lib: Arc::from(lib), - symbol: Arc::from(symbol), - kwargs: Arc::from(kwargs), - }, - options: FunctionOptions { - collect_groups, - input_wildcard_expansion, - returns_scalar, - cast_to_supertypes, - pass_name_to_apply, - changes_length, - ..Default::default() - }, - } - .into()) - } - #[cfg(feature = "hist")] #[pyo3(signature = (bins, bin_count, include_category, include_breakpoint))] fn hist( diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 9f3a713e013e3..fde544a6ce41a 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -1,4 +1,3 @@ -use polars::lazy::dsl::lit; use polars::prelude::*; use polars::series::ops::NullBehavior; use pyo3::prelude::*; diff --git a/py-polars/src/expr/meta.rs b/py-polars/src/expr/meta.rs index 658ebc6329a21..62af2805c203b 100644 --- a/py-polars/src/expr/meta.rs +++ b/py-polars/src/expr/meta.rs @@ -89,7 +89,7 @@ impl PyExpr { } #[cfg(all(feature = "json", feature = "serde_json"))] - fn meta_write_json(&self, py_f: PyObject) -> PyResult<()> { + fn serialize(&self, py_f: PyObject) -> PyResult<()> { let file = BufWriter::new(get_file_like(py_f, true)?); serde_json::to_writer(file, &self.inner) .map_err(|err| PyValueError::new_err(format!("{err:?}")))?; @@ -97,17 +97,28 @@ impl PyExpr { } #[staticmethod] - fn meta_read_json(value: &str) -> PyResult { - #[cfg(feature = "json")] - { - let inner: polars_lazy::prelude::Expr = serde_json::from_str(value) - .map_err(|_| PyPolarsErr::from(polars_err!(ComputeError: "could not serialize")))?; - Ok(PyExpr { inner }) - } - #[cfg(not(feature = "json"))] - { - panic!("activate 'json' feature") - } + #[cfg(feature = "json")] + fn deserialize(py_f: PyObject) -> PyResult { + // it is faster to first read to memory and then parse: https://github.com/serde-rs/json/issues/160 + // so don't bother with files. + let mut json = String::new(); + let _ = get_file_like(py_f, false)? + .read_to_string(&mut json) + .unwrap(); + + // SAFETY: + // we skipped the serializing/deserializing of the static in lifetime in `DataType` + // so we actually don't have a lifetime at all when serializing. + + // &str still has a lifetime. But it's ok, because we drop it immediately + // in this scope + let json = unsafe { std::mem::transmute::<&'_ str, &'static str>(json.as_str()) }; + + let inner: polars_lazy::prelude::Expr = serde_json::from_str(json).map_err(|_| { + let msg = "could not deserialize input into an expression"; + PyPolarsErr::from(polars_err!(ComputeError: msg)) + })?; + Ok(PyExpr { inner }) } fn meta_tree_format(&self) -> PyResult { diff --git a/py-polars/src/expr/name.rs b/py-polars/src/expr/name.rs index 8c3479e40a324..821cab8fbefb4 100644 --- a/py-polars/src/expr/name.rs +++ b/py-polars/src/expr/name.rs @@ -1,5 +1,4 @@ use polars::prelude::*; -use polars_plan::dsl::FieldsNameMapper; use pyo3::prelude::*; use smartstring::alias::String as SmartString; diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index 917aa4936ad91..128596bb13a9d 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -1,7 +1,6 @@ use std::any::Any; use polars::prelude::*; -use polars_core::prelude::QuantileInterpolOptions; use pyo3::prelude::*; use pyo3::types::PyFloat; diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index c1f33c9cfbca0..704658f9d8ccf 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -1,5 +1,4 @@ use polars::lazy::dsl; -use polars::lazy::dsl::Expr; use polars::prelude::*; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; @@ -8,7 +7,7 @@ use pyo3::types::{PyBool, PyBytes, PyFloat, PyInt, PyString}; use crate::conversion::{get_lf, Wrap}; use crate::expr::ToExprs; use crate::map::lazy::binary_lambda; -use crate::prelude::{vec_extract_wrapped, DataType, DatetimeArgs, DurationArgs, ObjectValue}; +use crate::prelude::{vec_extract_wrapped, ObjectValue}; use crate::{map, PyDataFrame, PyExpr, PyLazyFrame, PyPolarsErr, PySeries}; macro_rules! set_unwrapped_or_0 { diff --git a/py-polars/src/functions/misc.rs b/py-polars/src/functions/misc.rs index 593244618f03f..8c4116e988323 100644 --- a/py-polars/src/functions/misc.rs +++ b/py-polars/src/functions/misc.rs @@ -1,10 +1,55 @@ +use std::sync::Arc; + +use polars_plan::prelude::*; use pyo3::prelude::*; use crate::conversion::Wrap; +use crate::expr::ToExprs; use crate::prelude::DataType; +use crate::PyExpr; #[pyfunction] pub fn dtype_str_repr(dtype: Wrap) -> PyResult { let dtype = dtype.0; Ok(dtype.to_string()) } + +#[cfg(feature = "ffi_plugin")] +#[pyfunction] +pub fn register_plugin_function( + plugin_path: &str, + function_name: &str, + args: Vec, + kwargs: Vec, + is_elementwise: bool, + input_wildcard_expansion: bool, + returns_scalar: bool, + cast_to_supertype: bool, + pass_name_to_apply: bool, + changes_length: bool, +) -> PyResult { + let collect_groups = if is_elementwise { + ApplyOptions::ElementWise + } else { + ApplyOptions::GroupWise + }; + + Ok(Expr::Function { + input: args.to_exprs(), + function: FunctionExpr::FfiPlugin { + lib: Arc::from(plugin_path), + symbol: Arc::from(function_name), + kwargs: Arc::from(kwargs), + }, + options: FunctionOptions { + collect_groups, + input_wildcard_expansion, + returns_scalar, + cast_to_supertypes: cast_to_supertype, + pass_name_to_apply, + changes_length, + ..Default::default() + }, + } + .into()) +} diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index ddc03855ca97b..b1c3f6b80d393 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -6,30 +6,15 @@ use std::num::NonZeroUsize; use std::path::PathBuf; pub use exitable::PyInProcessQuery; -#[cfg(feature = "csv")] -use polars::io::csv::SerializeOptions; use polars::io::RowIndex; -#[cfg(feature = "csv")] -use polars::lazy::frame::LazyCsvReader; -#[cfg(feature = "json")] -use polars::lazy::frame::LazyJsonLineReader; -use polars::lazy::frame::{AllowedOptimizations, LazyFrame}; -use polars::lazy::prelude::col; -#[cfg(feature = "csv")] -use polars::prelude::CsvEncoding; -use polars::prelude::{ClosedWindow, Field, JoinType, Schema}; use polars::time::*; -use polars_core::frame::explode::MeltArgs; -use polars_core::frame::UniqueKeepStrategy; use polars_core::prelude::*; -use polars_ops::prelude::AsOfOptions; use polars_rs::io::cloud::CloudOptions; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyList}; use crate::arrow_interop::to_rust::pyarrow_schema_to_rust; -use crate::conversion::Wrap; use crate::error::PyPolarsErr; use crate::expr::ToExprs; use crate::file::get_file_like; @@ -319,7 +304,7 @@ impl PyLazyFrame { #[cfg(feature = "ipc")] #[staticmethod] - #[pyo3(signature = (path, paths, n_rows, cache, rechunk, row_index, memory_map))] + #[pyo3(signature = (path, paths, n_rows, cache, rechunk, row_index, memory_map, cloud_options, retries))] fn new_from_ipc( path: Option, paths: Vec, @@ -328,14 +313,45 @@ impl PyLazyFrame { rechunk: bool, row_index: Option<(String, IdxSize)>, memory_map: bool, + cloud_options: Option>, + retries: usize, ) -> PyResult { let row_index = row_index.map(|(name, offset)| RowIndex { name, offset }); + + #[cfg(feature = "cloud")] + let cloud_options = { + let first_path = if let Some(path) = &path { + path + } else { + paths + .first() + .ok_or_else(|| PyValueError::new_err("expected a path argument"))? + }; + + let first_path_url = first_path.to_string_lossy(); + let mut cloud_options = cloud_options + .map(|kv| parse_cloud_options(&first_path_url, kv)) + .transpose()?; + if retries > 0 { + cloud_options = + cloud_options + .or_else(|| Some(CloudOptions::default())) + .map(|mut options| { + options.max_retries = retries; + options + }); + } + cloud_options + }; + let args = ScanArgsIpc { n_rows, cache, rechunk, row_index, memmap: memory_map, + #[cfg(feature = "cloud")] + cloud_options, }; let lf = if let Some(path) = &path { diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 1dcb20557e1aa..cdd1725f6f9e4 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -305,5 +305,9 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { pyo3_built!(py, build, "build", "time", "deps", "features", "host", "target", "git"), )?; + // Plugins + m.add_wrapped(wrap_pyfunction!(functions::register_plugin_function)) + .unwrap(); + Ok(()) } diff --git a/py-polars/src/map/dataframe.rs b/py-polars/src/map/dataframe.rs index 16ac34120fa0c..52882fd5db3d8 100644 --- a/py-polars/src/map/dataframe.rs +++ b/py-polars/src/map/dataframe.rs @@ -1,14 +1,10 @@ use polars::prelude::*; use polars_core::frame::row::{rows_to_schema_first_non_null, Row}; use polars_core::series::SeriesIter; -use pyo3::conversion::{FromPyObject, IntoPy}; use pyo3::prelude::*; use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple}; use super::*; -use crate::conversion::Wrap; -use crate::error::PyPolarsErr; -use crate::series::PySeries; use crate::PyDataFrame; fn get_iters(df: &DataFrame) -> Vec { diff --git a/py-polars/src/map/lazy.rs b/py-polars/src/map/lazy.rs index 243772fa8d366..75084783a2959 100644 --- a/py-polars/src/map/lazy.rs +++ b/py-polars/src/map/lazy.rs @@ -103,7 +103,7 @@ pub(crate) fn binary_lambda( let pyseries = if let Ok(expr) = result_series_wrapper.getattr(py, "_pyexpr") { let pyexpr = expr.extract::(py).unwrap(); let expr = pyexpr.inner; - let df = DataFrame::new_no_checks(vec![]); + let df = DataFrame::empty(); let out = df .lazy() .select([expr]) diff --git a/py-polars/src/map/series.rs b/py-polars/src/map/series.rs index d0a1e08b0f8e7..afe4172eb788f 100644 --- a/py-polars/src/map/series.rs +++ b/py-polars/src/map/series.rs @@ -1,13 +1,10 @@ -use polars::chunked_array::builder::get_list_builder; use polars::prelude::*; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyCFunction, PyDict, PyFloat, PyList, PyString, PyTuple}; +use pyo3::types::{PyBool, PyCFunction, PyFloat, PyList, PyString, PyTuple}; use super::*; use crate::conversion::slice_to_wrapped; use crate::py_modules::SERIES; -use crate::series::PySeries; -use crate::{PyPolarsErr, Wrap}; /// Find the output type and dispatch to that implementation. fn infer_and_finish<'a, A: ApplyLambda<'a>>( @@ -126,9 +123,6 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>( pub trait ApplyLambda<'a> { fn apply_lambda_unknown(&'a self, _py: Python, _lambda: &'a PyAny) -> PyResult; - /// Apply a lambda that doesn't change output types - fn apply_lambda(&'a self, _py: Python, _lambda: &'a PyAny) -> PyResult; - // Used to store a struct type fn apply_to_struct( &'a self, @@ -251,11 +245,6 @@ impl<'a> ApplyLambda<'a> for BooleanChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - self.apply_lambda_with_bool_out_type(py, lambda, 0, None) - .map(|ca| PySeries::new(ca.into_series())) - } - fn apply_to_struct( &'a self, py: Python, @@ -547,11 +536,6 @@ where .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - self.apply_lambda_with_primitive_out_type::(py, lambda, 0, None) - .map(|ca| PySeries::new(ca.into_series())) - } - fn apply_to_struct( &'a self, py: Python, @@ -838,11 +822,6 @@ impl<'a> ApplyLambda<'a> for StringChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - let ca = self.apply_lambda_with_string_out_type(py, lambda, 0, None)?; - Ok(ca.into_series().into()) - } - fn apply_to_struct( &'a self, py: Python, @@ -1107,40 +1086,6 @@ impl<'a> ApplyLambda<'a> for StringChunked { } } -fn append_series( - pypolars: &PyModule, - builder: &mut (impl ListBuilderTrait + ?Sized), - lambda: &PyAny, - series: Series, -) -> PyResult<()> { - // create a PySeries struct/object for Python - let pyseries = PySeries::new(series); - // Wrap this PySeries object in the python side Series wrapper - let python_series_wrapper = pypolars - .getattr("wrap_s") - .unwrap() - .call1((pyseries,)) - .unwrap(); - // call the lambda en get a python side Series wrapper - let out = lambda.call1((python_series_wrapper,)); - match out { - Ok(out) => { - // unpack the wrapper in a PySeries - let py_pyseries = out - .getattr("_s") - .expect("could not get Series attribute '_s'"); - let pyseries = py_pyseries.extract::()?; - builder - .append_series(&pyseries.series) - .map_err(PyPolarsErr::from)?; - }, - Err(_) => { - builder.append_opt_series(None).map_err(PyPolarsErr::from)?; - }, - }; - Ok(()) -} - fn call_series_lambda(pypolars: &PyModule, lambda: &PyAny, series: Series) -> Option { // create a PySeries struct/object for Python let pyseries = PySeries::new(series); @@ -1196,74 +1141,6 @@ impl<'a> ApplyLambda<'a> for ListChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - // get the pypolars module - let pypolars = PyModule::import(py, "polars")?; - - match self.dtype() { - DataType::List(dt) => { - let mut builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - if !self.has_validity() { - let mut it = self.into_no_null_iter(); - // use first value to get dtype and replace default builder - if let Some(series) = it.next() { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder - .append_opt_series(Some(&out_series)) - .map_err(PyPolarsErr::from)?; - } else { - let mut builder = - get_list_builder(dt, 0, 1, self.name()).map_err(PyPolarsErr::from)?; - let ca = builder.finish(); - return Ok(PySeries::new(ca.into_series())); - } - for series in it { - append_series(pypolars, &mut *builder, lambda, series)?; - } - } else { - let mut it = self.into_iter(); - let mut nulls = 0; - - // use first values to get dtype and replace default builders - // continue until no null is found - for opt_series in &mut it { - if let Some(series) = opt_series { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder - .append_opt_series(Some(&out_series)) - .map_err(PyPolarsErr::from)?; - break; - } else { - nulls += 1; - } - } - for _ in 0..nulls { - builder.append_opt_series(None).map_err(PyPolarsErr::from)?; - } - for opt_series in it { - if let Some(series) = opt_series { - append_series(pypolars, &mut *builder, lambda, series)?; - } else { - builder.append_opt_series(None).unwrap() - } - } - }; - let ca = builder.finish(); - Ok(PySeries::new(ca.into_series())) - }, - _ => unimplemented!(), - } - } - fn apply_to_struct( &'a self, py: Python, @@ -1679,70 +1556,6 @@ impl<'a> ApplyLambda<'a> for ArrayChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - // get the pypolars module - let pypolars = PyModule::import(py, "polars")?; - - match self.dtype() { - DataType::List(dt) => { - let mut builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - if !self.has_validity() { - let mut it = self.into_no_null_iter(); - // use first value to get dtype and replace default builder - if let Some(series) = it.next() { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder.append_opt_series(Some(&out_series)); - } else { - let mut builder = - get_list_builder(dt, 0, 1, self.name()).map_err(PyPolarsErr::from)?; - let ca = builder.finish(); - return Ok(PySeries::new(ca.into_series())); - } - for series in it { - append_series(pypolars, &mut *builder, lambda, series)?; - } - } else { - let mut it = self.into_iter(); - let mut nulls = 0; - - // use first values to get dtype and replace default builders - // continue until no null is found - for opt_series in &mut it { - if let Some(series) = opt_series { - let out_series = call_series_lambda(pypolars, lambda, series) - .expect("Cannot determine dtype because lambda failed; Make sure that your udf returns a Series"); - let dt = out_series.dtype(); - builder = get_list_builder(dt, self.len() * 5, self.len(), self.name()) - .map_err(PyPolarsErr::from)?; - builder.append_opt_series(Some(&out_series)); - break; - } else { - nulls += 1; - } - } - for _ in 0..nulls { - builder.append_opt_series(None); - } - for opt_series in it { - if let Some(series) = opt_series { - append_series(pypolars, &mut *builder, lambda, series)?; - } else { - builder.append_opt_series(None) - } - } - }; - let ca = builder.finish(); - Ok(PySeries::new(ca.into_series())) - }, - _ => unimplemented!(), - } - } - fn apply_to_struct( &'a self, py: Python, @@ -2149,18 +1962,6 @@ impl<'a> ApplyLambda<'a> for ObjectChunked { .into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - #[cfg(feature = "object")] - { - self.apply_lambda_with_object_out_type(py, lambda, 0, None) - .map(|ca| PySeries::new(ca.into_series())) - } - #[cfg(not(feature = "object"))] - { - todo!() - } - } - fn apply_to_struct( &'a self, _py: Python, @@ -2447,10 +2248,6 @@ impl<'a> ApplyLambda<'a> for StructChunked { Ok(self.clone().into_series().into()) } - fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult { - self.apply_lambda_unknown(py, lambda) - } - fn apply_to_struct( &'a self, py: Python, diff --git a/py-polars/src/on_startup.rs b/py-polars/src/on_startup.rs index b592d45b675a1..f320d19d43344 100644 --- a/py-polars/src/on_startup.rs +++ b/py-polars/src/on_startup.rs @@ -1,20 +1,17 @@ use std::any::Any; -use std::sync::Arc; use polars::prelude::*; use polars_core::chunked_array::object::builder::ObjectChunkedBuilder; use polars_core::chunked_array::object::registry; use polars_core::chunked_array::object::registry::AnonymousObjectBuilder; use polars_core::error::PolarsError::ComputeError; -use polars_core::error::PolarsResult; -use polars_core::frame::DataFrame; use polars_error::PolarsWarning; use pyo3::intern; use pyo3::prelude::*; use crate::dataframe::PyDataFrame; use crate::map::lazy::{call_lambda_with_series, ToSeries}; -use crate::prelude::{python_udf, ObjectValue}; +use crate::prelude::ObjectValue; use crate::py_modules::{POLARS, UTILS}; use crate::Wrap; diff --git a/py-polars/src/py_modules.rs b/py-polars/src/py_modules.rs index 6c7dbd2658a18..7b0d370613497 100644 --- a/py-polars/src/py_modules.rs +++ b/py-polars/src/py_modules.rs @@ -5,7 +5,7 @@ pub(crate) static POLARS: Lazy = Lazy::new(|| Python::with_gil(|py| PyModule::import(py, "polars").unwrap().to_object(py))); pub(crate) static UTILS: Lazy = - Lazy::new(|| Python::with_gil(|py| POLARS.getattr(py, "utils").unwrap())); + Lazy::new(|| Python::with_gil(|py| POLARS.getattr(py, "_utils").unwrap())); pub(crate) static SERIES: Lazy = Lazy::new(|| Python::with_gil(|py| POLARS.getattr(py, "Series").unwrap())); diff --git a/py-polars/src/series/aggregation.rs b/py-polars/src/series/aggregation.rs index 9ed7819d56ac8..035ec49e136dc 100644 --- a/py-polars/src/series/aggregation.rs +++ b/py-polars/src/series/aggregation.rs @@ -101,7 +101,14 @@ impl PySeries { } fn product(&self, py: Python) -> PyResult { - Ok(Wrap(self.series.product().get(0).map_err(PyPolarsErr::from)?).into_py(py)) + Ok(Wrap( + self.series + .product() + .map_err(PyPolarsErr::from)? + .get(0) + .map_err(PyPolarsErr::from)?, + ) + .into_py(py)) } fn quantile( diff --git a/py-polars/src/series/construction.rs b/py-polars/src/series/construction.rs index c852be4f7edc6..6a8d6c736b65e 100644 --- a/py-polars/src/series/construction.rs +++ b/py-polars/src/series/construction.rs @@ -9,10 +9,10 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use crate::arrow_interop::to_rust::array_to_rust; +use crate::conversion::any_value::py_object_to_any_value; use crate::conversion::{slice_extract_wrapped, vec_extract_wrapped, Wrap}; use crate::error::PyPolarsErr; use crate::prelude::ObjectValue; -use crate::series::ToSeries; use crate::PySeries; // Init with numpy arrays. @@ -21,7 +21,7 @@ macro_rules! init_method { #[pymethods] impl PySeries { #[staticmethod] - fn $name(py: Python, name: &str, array: &PyArray1<$type>, _strict: bool) -> PySeries { + fn $name(py: Python, name: &str, array: &PyArray1<$type>, _strict: bool) -> Self { mmap_numpy_array(py, name, array) } } @@ -51,7 +51,7 @@ fn mmap_numpy_array( #[pymethods] impl PySeries { #[staticmethod] - fn new_bool(py: Python, name: &str, array: &PyArray1, _strict: bool) -> PySeries { + fn new_bool(py: Python, name: &str, array: &PyArray1, _strict: bool) -> Self { let array = array.readonly(); let vals = array.as_slice().unwrap(); py.allow_threads(|| PySeries { @@ -60,7 +60,7 @@ impl PySeries { } #[staticmethod] - fn new_f32(py: Python, name: &str, array: &PyArray1, nan_is_null: bool) -> PySeries { + fn new_f32(py: Python, name: &str, array: &PyArray1, nan_is_null: bool) -> Self { if nan_is_null { let array = array.readonly(); let vals = array.as_slice().unwrap(); @@ -75,7 +75,7 @@ impl PySeries { } #[staticmethod] - fn new_f64(py: Python, name: &str, array: &PyArray1, nan_is_null: bool) -> PySeries { + fn new_f64(py: Python, name: &str, array: &PyArray1, nan_is_null: bool) -> Self { if nan_is_null { let array = array.readonly(); let vals = array.as_slice().unwrap(); @@ -93,7 +93,7 @@ impl PySeries { #[pymethods] impl PySeries { #[staticmethod] - fn new_opt_bool(name: &str, obj: &PyAny, strict: bool) -> PyResult { + fn new_opt_bool(name: &str, obj: &PyAny, strict: bool) -> PyResult { let len = obj.len()?; let mut builder = BooleanChunkedBuilder::new(name, len); @@ -158,7 +158,7 @@ macro_rules! init_method_opt { #[pymethods] impl PySeries { #[staticmethod] - fn $name(name: &str, obj: &PyAny, strict: bool) -> PyResult { + fn $name(name: &str, obj: &PyAny, strict: bool) -> PyResult { new_primitive::<$type>(name, obj, strict) } } @@ -184,26 +184,46 @@ init_method_opt!(new_opt_f64, Float64Type, f64); )] impl PySeries { #[staticmethod] - fn new_from_any_values( - name: &str, - val: Vec>>, - strict: bool, - ) -> PyResult { - // From AnyValues is fallible. - let avs = slice_extract_wrapped(&val); - let s = Series::from_any_values(name, avs, strict).map_err(PyPolarsErr::from)?; - Ok(s.into()) + fn new_from_any_values(name: &str, values: Vec<&PyAny>, strict: bool) -> PyResult { + let any_values_result = values + .iter() + .map(|v| py_object_to_any_value(v, strict)) + .collect::>>(); + let result = any_values_result.and_then(|avs| { + let s = + Series::from_any_values(name, avs.as_slice(), strict).map_err(PyPolarsErr::from)?; + Ok(s.into()) + }); + + // Fall back to Object type for non-strict construction. + if !strict && result.is_err() { + let s = Python::with_gil(|py| { + let objects = values + .into_iter() + .map(|v| ObjectValue { + inner: v.to_object(py), + }) + .collect(); + Self::new_object(py, name, objects, strict) + }); + return Ok(s); + } + + result } #[staticmethod] fn new_from_any_values_and_dtype( name: &str, - val: Vec>>, + values: Vec<&PyAny>, dtype: Wrap, strict: bool, - ) -> PyResult { - let avs = slice_extract_wrapped(&val); - let s = Series::from_any_values_and_dtype(name, avs, &dtype.0, strict) + ) -> PyResult { + let any_values = values + .into_iter() + .map(|v| py_object_to_any_value(v, strict)) + .collect::>>()?; + let s = Series::from_any_values_and_dtype(name, any_values.as_slice(), &dtype.0, strict) .map_err(PyPolarsErr::from)?; Ok(s.into()) } @@ -244,14 +264,15 @@ impl PySeries { s.into() } #[cfg(not(feature = "object"))] - { - todo!() - } + panic!("activate 'object' feature") } #[staticmethod] - fn new_series_list(name: &str, val: Vec, _strict: bool) -> Self { - let series_vec = val.to_series(); + fn new_series_list(name: &str, val: Vec>, _strict: bool) -> Self { + let series_vec: Vec> = val + .iter() + .map(|v| v.as_ref().map(|py_s| py_s.clone().series)) + .collect(); Series::new(name, &series_vec).into() } @@ -265,40 +286,36 @@ impl PySeries { _strict: bool, ) -> PyResult { if val.is_empty() { - let series = - Series::new_empty(name, &DataType::Array(Box::new(inner.unwrap().0), width)); - Ok(series.into()) + let s = Series::new_empty(name, &DataType::Array(Box::new(inner.unwrap().0), width)); + return Ok(s.into()); + }; + + let val = vec_extract_wrapped(val); + let out = if let Some(inner) = inner { + Series::from_any_values_and_dtype( + name, + val.as_ref(), + &DataType::Array(Box::new(inner.0), width), + true, + ) + .map_err(PyPolarsErr::from)? } else { - let val = vec_extract_wrapped(val); - return if let Some(inner) = inner { - let series = Series::from_any_values_and_dtype( - name, - val.as_ref(), - &DataType::Array(Box::new(inner.0), width), - true, - ) - .map_err(PyPolarsErr::from)?; - Ok(series.into()) - } else { - let series = Series::new(name, &val); - match series.dtype() { - DataType::List(list_inner) => { - let series = series - .cast(&DataType::Array( - Box::new(inner.map(|dt| dt.0).unwrap_or(*list_inner.clone())), - width, - )) - .map_err(PyPolarsErr::from)?; - Ok(series.into()) - }, - _ => Err(PyValueError::new_err("could not create Array from input")), - } - }; - } + let series = Series::new(name, &val); + match series.dtype() { + DataType::List(list_inner) => series + .cast(&DataType::Array( + Box::new(inner.map(|dt| dt.0).unwrap_or(*list_inner.clone())), + width, + )) + .map_err(PyPolarsErr::from)?, + _ => return Err(PyValueError::new_err("could not create Array from input")), + } + }; + Ok(out.into()) } #[staticmethod] - fn new_decimal(name: &str, val: Vec>>, strict: bool) -> PyResult { + fn new_decimal(name: &str, val: Vec>>, strict: bool) -> PyResult { // TODO: do we have to respect 'strict' here? It's possible if we want to. let avs = slice_extract_wrapped(&val); // Create a fake dtype with a placeholder "none" scale, to be inferred later. diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index 71d84af1104fe..15d3247a863b1 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -6,7 +6,7 @@ use pyo3::types::PyList; use crate::conversion::chunked_array::{decimal_to_pyobject_iter, time_to_pyobject_iter}; use crate::error::PyPolarsErr; -use crate::prelude::{ObjectValue, *}; +use crate::prelude::*; use crate::{arrow_interop, raise_err, PySeries}; #[pymethods] diff --git a/py-polars/src/to_numpy.rs b/py-polars/src/to_numpy.rs index 70fd1fb74da16..53bb9ff014fa5 100644 --- a/py-polars/src/to_numpy.rs +++ b/py-polars/src/to_numpy.rs @@ -7,7 +7,6 @@ use polars_core::prelude::*; use polars_core::utils::try_get_supertype; use polars_core::with_match_physical_numeric_polars_type; use pyo3::prelude::*; -use pyo3::{IntoPy, PyAny, PyObject, Python}; use crate::conversion::Wrap; use crate::dataframe::PyDataFrame; @@ -49,10 +48,16 @@ where } #[pymethods] -#[allow(clippy::wrong_self_convention)] impl PySeries { + /// Create a view of the data as a NumPy ndarray. + /// + /// WARNING: The resulting view will show the underlying value for nulls, + /// which may be any value. The caller is responsible for handling nulls + /// appropriately. + #[allow(clippy::wrong_self_convention)] pub fn to_numpy_view(&self, py: Python) -> Option { - if self.series.null_count() != 0 || self.series.chunks().len() > 1 { + // NumPy arrays are always contiguous + if self.series.n_chunks() > 1 { return None; } @@ -62,15 +67,18 @@ impl PySeries { // Object to the series keep the memory alive. let owner = self.clone().into_py(py); with_match_physical_numeric_polars_type!(self.series.dtype(), |$T| { - let ca: &ChunkedArray<$T> = self.series.unpack::<$T>().unwrap(); - let slice = ca.cont_slice().unwrap(); - unsafe { Some(create_borrowed_np_array::<<$T as PolarsNumericType>::Native, _>( - py, - dims, - flags::NPY_ARRAY_FARRAY_RO, - slice.as_ptr() as _, - owner, - )) } + let ca: &ChunkedArray<$T> = self.series.unpack::<$T>().unwrap(); + let slice = ca.data_views().next().unwrap(); + let view = unsafe { + create_borrowed_np_array::<<$T as PolarsNumericType>::Native, _>( + py, + dims, + flags::NPY_ARRAY_FARRAY_RO, + slice.as_ptr() as _, + owner, + ) + }; + Some(view) }) }, _ => None, diff --git a/py-polars/tests/benchmark/test_release.py b/py-polars/tests/benchmark/test_release.py index 8fb6981202385..d9ad7ca49f2c3 100644 --- a/py-polars/tests/benchmark/test_release.py +++ b/py-polars/tests/benchmark/test_release.py @@ -5,6 +5,7 @@ To run these tests: pytest -m benchmark """ + import time from pathlib import Path from typing import cast diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index b4fffcca14fec..aa906a0456c91 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -26,6 +26,7 @@ all outputs, set `IGNORE_RESULT_ALL=True` below. Do note that this does mean no output is being checked anymore. """ + from __future__ import annotations import doctest diff --git a/py-polars/tests/docs/test_user_guide.py b/py-polars/tests/docs/test_user_guide.py index 3b17f7196c770..08be6fe9dfbfe 100644 --- a/py-polars/tests/docs/test_user_guide.py +++ b/py-polars/tests/docs/test_user_guide.py @@ -1,4 +1,5 @@ """Run all Python code snippets.""" + import os import runpy from pathlib import Path diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index e8af621d18481..39836c3880144 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -8,10 +8,10 @@ from hypothesis import assume, given import polars as pl +from polars._utils.convert import parse_as_duration_string from polars.testing import assert_frame_equal from polars.testing.parametric.primitives import column, dataframes from polars.testing.parametric.strategies import strategy_closed, strategy_time_unit -from polars.utils.convert import _timedelta_to_pl_duration if TYPE_CHECKING: from polars.type_aliases import ClosedInterval, TimeUnit @@ -20,10 +20,10 @@ @given( period=st.timedeltas( min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) - ).map(_timedelta_to_pl_duration), + ).map(parse_as_duration_string), offset=st.timedeltas( min_value=timedelta(microseconds=0), max_value=timedelta(days=1000) - ).map(_timedelta_to_pl_duration), + ).map(parse_as_duration_string), closed=strategy_closed, data=st.data(), time_unit=strategy_time_unit, @@ -85,7 +85,7 @@ def test_rolling( @given( window_size=st.timedeltas( min_value=timedelta(microseconds=0), max_value=timedelta(days=2) - ).map(_timedelta_to_pl_duration), + ).map(parse_as_duration_string), closed=strategy_closed, data=st.data(), time_unit=strategy_time_unit, @@ -95,8 +95,8 @@ def test_rolling( "max", "mean", "sum", - # "std", blocked by https://github.com/pola-rs/polars/issues/11140 - # "var", blocked by https://github.com/pola-rs/polars/issues/11140 + "std", + "var", "median", ] ), diff --git a/py-polars/tests/parametric/test_series.py b/py-polars/tests/parametric/test_series.py index 27d4062afe767..4dedafaf888b5 100644 --- a/py-polars/tests/parametric/test_series.py +++ b/py-polars/tests/parametric/test_series.py @@ -3,95 +3,13 @@ # ------------------------------------------------- from __future__ import annotations -from typing import Any - from hypothesis import given, settings -from hypothesis.strategies import booleans, floats, sampled_from +from hypothesis.strategies import sampled_from import polars as pl -from polars.expr.expr import _prepare_alpha from polars.testing import assert_series_equal from polars.testing.parametric import series - -def alpha_guard(**decay_param: float) -> bool: - """Protects against unnecessary noise in small number regime.""" - if not next(iter(decay_param.values())): - return True - alpha = _prepare_alpha(**decay_param) - return ((1 - alpha) if round(alpha) else alpha) > 1e-6 - - -@given( - s=series( - min_size=4, - dtype=pl.Float64, - null_probability=0.05, - strategy=floats(min_value=-1e8, max_value=1e8), - ), - half_life=floats(min_value=0, max_value=4, exclude_min=True).filter( - lambda x: alpha_guard(half_life=x) - ), - com=floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), - span=floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), - ignore_nulls=booleans(), - adjust=booleans(), - bias=booleans(), -) -def test_ewm_methods( - s: pl.Series, - com: float | None, - span: float | None, - half_life: float | None, - ignore_nulls: bool, - adjust: bool, - bias: bool, -) -> None: - # validate a large set of varied EWM calculations - for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]: - alpha = _prepare_alpha(**decay_param) - - # convert parametrically-generated series to pandas, then use that as a - # reference implementation for comparison (after normalising NaN/None) - p = s.to_pandas() - - # note: skip min_periods < 2, due to pandas-side inconsistency: - # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178 - for mp in range(2, len(s), len(s) // 3): - # consolidate ewm parameters - pl_params: dict[str, Any] = { - "min_periods": mp, - "adjust": adjust, - "ignore_nulls": ignore_nulls, - } - pl_params.update(decay_param) - pd_params = pl_params.copy() - if "half_life" in pl_params: - pd_params["halflife"] = pd_params.pop("half_life") - if "ignore_nulls" in pl_params: - pd_params["ignore_na"] = pd_params.pop("ignore_nulls") - - # mean: - ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None) - ewm_mean_pd = pl.Series(p.ewm(**pd_params).mean()) - if alpha == 1: - # apply fill-forward to nulls to match pandas - # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124 - ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward") - - assert_series_equal(ewm_mean_pl, ewm_mean_pd, atol=1e-07) - - # std: - ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None) - ewm_std_pd = pl.Series(p.ewm(**pd_params).std(bias=bias)) - assert_series_equal(ewm_std_pl, ewm_std_pd, atol=1e-07) - - # var: - ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None) - ewm_var_pd = pl.Series(p.ewm(**pd_params).var(bias=bias)) - assert_series_equal(ewm_var_pl, ewm_var_pd, atol=1e-07) - - # TODO: once Decimal is a little further along, start actively probing it # @given( # s=series(max_size=10, dtype=pl.Decimal, null_probability=0.1), diff --git a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py index bff14e5a461ed..1498197c7461e 100644 --- a/py-polars/tests/unit/constructors/test_any_value_fallbacks.py +++ b/py-polars/tests/unit/constructors/test_any_value_fallbacks.py @@ -8,13 +8,15 @@ import pytest import polars as pl +from polars._utils.wrap import wrap_s from polars.polars import PySeries -from polars.utils._wrap import wrap_s @pytest.mark.parametrize( ("dtype", "values"), [ + (pl.Int64, [-1, 0, 100_000, None]), + (pl.Float64, [-1.5, 0.0, 10.0, None]), (pl.Boolean, [True, False, None]), (pl.Binary, [b"123", b"xyz", None]), (pl.String, ["123", "xyz", None]), @@ -32,6 +34,8 @@ def test_fallback_with_dtype_strict( @pytest.mark.parametrize( ("dtype", "values"), [ + (pl.Int64, [1.0, 2.0]), + (pl.Float64, [1, 2]), (pl.Boolean, [0, 1]), (pl.Binary, ["123", "xyz"]), (pl.String, [b"123", b"xyz"]), @@ -41,16 +45,26 @@ def test_fallback_with_dtype_strict_failure( dtype: pl.PolarsDataType, values: list[Any] ) -> None: with pytest.raises(pl.SchemaError, match="unexpected value"): - PySeries.new_from_any_values_and_dtype("", values, pl.Boolean, strict=True) + PySeries.new_from_any_values_and_dtype("", values, dtype, strict=True) @pytest.mark.parametrize( ("dtype", "values", "expected"), [ + ( + pl.Int64, + [False, True, 0, -1, 0.0, 2.5, date(1970, 1, 2), "5", "xyz"], + [0, 1, 0, -1, 0, 2, 1, 5, None], + ), + ( + pl.Float64, + [False, True, 0, -1, 0.0, 2.5, date(1970, 1, 2), "5", "xyz"], + [0.0, 1.0, 0.0, -1.0, 0.0, 2.5, 1.0, 5.0, None], + ), ( pl.Boolean, - [False, True, 0, 1, 0.0, 2.5, date(1970, 1, 1)], - [False, True, False, True, False, True, None], + [False, True, 0, -1, 0.0, 2.5, date(1970, 1, 1), "true"], + [False, True, False, True, False, True, None, None], ), ( pl.Binary, @@ -71,3 +85,73 @@ def test_fallback_with_dtype_nonstrict( PySeries.new_from_any_values_and_dtype("", values, dtype, strict=False) ) assert result.to_list() == expected + + +@pytest.mark.parametrize( + ("values", "expected_dtype"), + [ + ([-1, 0, 100_000, None], pl.Int64), + ([-1.5, 0.0, 10.0, None], pl.Float64), + ([True, False, None], pl.Boolean), + ([b"123", b"xyz", None], pl.Binary), + (["123", "xyz", None], pl.String), + ], +) +def test_fallback_without_dtype_strict( + values: list[Any], expected_dtype: pl.PolarsDataType +) -> None: + result = wrap_s(PySeries.new_from_any_values("", values, strict=True)) + assert result.to_list() == values + + +@pytest.mark.parametrize( + "values", + [ + [1.0, 2], + [1, 2.0], + [False, 1], + [b"123", "xyz"], + ["123", b"xyz"], + ], +) +def test_fallback_without_dtype_strict_failure(values: list[Any]) -> None: + with pytest.raises(pl.SchemaError, match="unexpected value"): + PySeries.new_from_any_values("", values, strict=True) + + +@pytest.mark.parametrize( + ("values", "expected_dtype"), + [ + ([-1, 0, 100_000, None], pl.Int64), + ([-1.5, 0.0, 10.0, None], pl.Float64), + ([True, False, None], pl.Boolean), + ([b"123", b"xyz", None], pl.Binary), + (["123", "xyz", None], pl.String), + ], +) +def test_fallback_without_dtype_nonstrict_single_type( + values: list[Any], + expected_dtype: pl.PolarsDataType, +) -> None: + result = wrap_s(PySeries.new_from_any_values("", values, strict=False)) + assert result.dtype == expected_dtype + assert result.to_list() == values + + +@pytest.mark.parametrize( + ("values", "expected", "expected_dtype"), + [ + ([True, 2], [1, 2], pl.Int64), + ([1, 2.0], [1.0, 2.0], pl.Float64), + ([2.0, "c"], ["2.0", "c"], pl.String), + ([1, 2.0, b"d", date(2022, 1, 1)], [1, 2.0, b"d", date(2022, 1, 1)], pl.Object), + ], +) +def test_fallback_without_dtype_nonstrict_mixed_types( + values: list[Any], + expected_dtype: pl.PolarsDataType, + expected: list[Any], +) -> None: + result = wrap_s(PySeries.new_from_any_values("", values, strict=False)) + assert result.dtype == expected_dtype + assert result.to_list() == expected diff --git a/py-polars/tests/unit/constructors/test_constructors.py b/py-polars/tests/unit/constructors/test_constructors.py index e7fe1d4ec9d63..adf5ea28c5677 100644 --- a/py-polars/tests/unit/constructors/test_constructors.py +++ b/py-polars/tests/unit/constructors/test_constructors.py @@ -1,8 +1,7 @@ from __future__ import annotations -import sys from collections import OrderedDict, namedtuple -from datetime import date, datetime, timedelta, timezone +from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from random import shuffle from typing import TYPE_CHECKING, Any, List, Literal, NamedTuple @@ -14,20 +13,20 @@ from pydantic import BaseModel, Field, TypeAdapter import polars as pl -from polars.dependencies import _ZONEINFO_AVAILABLE, dataclasses, pydantic +from polars._utils.construction.utils import try_get_type_hints +from polars.datatypes import PolarsDataType, numpy_char_code_to_dtype +from polars.dependencies import dataclasses, pydantic from polars.exceptions import TimeZoneAwareConstructorWarning from polars.testing import assert_frame_equal, assert_series_equal -from polars.utils._construction import type_hints if TYPE_CHECKING: - from polars.datatypes import PolarsDataType + from collections.abc import Callable -if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo + + from polars.datatypes import PolarsDataType +else: + from polars._utils.convert import string_to_zoneinfo as ZoneInfo # ----------------------------------------------------------------------------------- @@ -264,7 +263,7 @@ class TradeNT(NamedTuple): assert df.rows() == raw_data # cover a miscellaneous edge-case when detecting the annotations - assert type_hints(obj=type(None)) == {} + assert try_get_type_hints(obj=type(None)) == {} def test_init_pydantic_2x() -> None: @@ -797,6 +796,45 @@ def test_init_series() -> None: assert_series_equal(s5, pl.Series("", [1, 2, 3], dtype=pl.Int8)) +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + (int, pl.Int64), + (bytes, pl.Binary), + (float, pl.Float64), + (str, pl.String), + (date, pl.Date), + (time, pl.Time), + (datetime, pl.Datetime("us")), + (timedelta, pl.Duration("us")), + (Decimal, pl.Decimal(precision=None, scale=0)), + ], +) +def test_init_py_dtype(dtype: Any, expected_dtype: PolarsDataType) -> None: + for s in ( + pl.Series("s", [None], dtype=dtype), + pl.Series("s", [], dtype=dtype), + ): + assert s.dtype == expected_dtype + + for df in ( + pl.DataFrame({"col": [None]}, schema={"col": dtype}), + pl.DataFrame({"col": []}, schema={"col": dtype}), + ): + assert df.schema == {"col": expected_dtype} + + +def test_init_py_dtype_misc_float() -> None: + assert pl.Series([100], dtype=float).dtype == pl.Float64 # type: ignore[arg-type] + + df = pl.DataFrame( + {"x": [100.0], "y": [200], "z": [None]}, + schema={"x": float, "y": float, "z": float}, + ) + assert df.schema == {"x": pl.Float64, "y": pl.Float64, "z": pl.Float64} + assert df.rows() == [(100.0, 200.0, None)] + + def test_init_seq_of_seq() -> None: # List of lists df = pl.DataFrame([[1, 2, 3], [4, 5, 6]], schema=["a", "b", "c"]) @@ -1423,7 +1461,7 @@ def test_nested_schema_construction2() -> None: def test_arrow_to_pyseries_with_one_chunk_does_not_copy_data() -> None: - from polars.utils._construction import arrow_to_pyseries + from polars._utils.construction import arrow_to_pyseries original_array = pa.chunked_array([[1, 2, 3]], type=pa.int64()) pyseries = arrow_to_pyseries("", original_array) @@ -1450,11 +1488,9 @@ def test_nested_categorical() -> None: def test_datetime_date_subclasses() -> None: - class FakeDate(date): - ... + class FakeDate(date): ... - class FakeDatetime(FakeDate, datetime): - ... + class FakeDatetime(FakeDate, datetime): ... result = pl.Series([FakeDatetime(2020, 1, 1, 3)]) expected = pl.Series([datetime(2020, 1, 1, 3)]) @@ -1538,3 +1574,24 @@ def test_df_schema_sequences_incorrect_length() -> None: ] with pytest.raises(ValueError): pl.DataFrame(schema=schema) # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("input", "infer_func", "expected_dtype"), + [ + ("f8", numpy_char_code_to_dtype, pl.Float64), + ("f4", numpy_char_code_to_dtype, pl.Float32), + ("i4", numpy_char_code_to_dtype, pl.Int32), + ("u1", numpy_char_code_to_dtype, pl.UInt8), + ("?", numpy_char_code_to_dtype, pl.Boolean), + ("m8", numpy_char_code_to_dtype, pl.Duration("us")), + ("M8", numpy_char_code_to_dtype, pl.Datetime("us")), + ], +) +def test_numpy_inference( + input: Any, + infer_func: Callable[[Any], PolarsDataType], + expected_dtype: PolarsDataType, +) -> None: + result = infer_func(input) + assert result == expected_dtype diff --git a/py-polars/tests/unit/constructors/test_dataframe.py b/py-polars/tests/unit/constructors/test_dataframe.py new file mode 100644 index 0000000000000..98f2519db9a4f --- /dev/null +++ b/py-polars/tests/unit/constructors/test_dataframe.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import sys +from typing import Any + +import pytest + +import polars as pl + + +def test_df_mixed_dtypes_string() -> None: + data = {"x": [["abc", 12, 34.5]], "y": [1]} + + with pytest.raises(pl.SchemaError, match="unexpected value"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + assert df.schema == {"x": pl.List(pl.String), "y": pl.Int64} + assert df.rows() == [(["abc", "12", "34.5"], 1)] + + +def test_df_mixed_dtypes_object() -> None: + data = {"x": [[b"abc", 12, 34.5]], "y": [1]} + # with pytest.raises(pl.SchemaError, match="unexpected value"): + with pytest.raises(pl.ComputeError, match="failed to determine supertype"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + assert df.schema == {"x": pl.Object, "y": pl.Int64} + assert df.rows() == [([b"abc", 12, 34.5], 1)] + + +def test_df_object() -> None: + class Foo: + def __init__(self, value: int) -> None: + self._value = value + + def __eq__(self, other: Any) -> bool: + return issubclass(other.__class__, self.__class__) and ( + self._value == other._value + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._value})" + + df = pl.DataFrame({"a": [Foo(1), Foo(2)]}) + assert df["a"].dtype == pl.Object + assert df.rows() == [(Foo(1),), (Foo(2),)] + + +def test_df_init_from_generator_dict_view() -> None: + d = {0: "x", 1: "y", 2: "z"} + data = { + "keys": d.keys(), + "vals": d.values(), + "itms": d.items(), + } + with pytest.raises(pl.SchemaError, match="unexpected value"): + pl.DataFrame(data, strict=True) + + df = pl.DataFrame(data, strict=False) + assert df.schema == { + "keys": pl.Int64, + "vals": pl.String, + "itms": pl.List(pl.String), + } + assert df.to_dict(as_series=False) == { + "keys": [0, 1, 2], + "vals": ["x", "y", "z"], + "itms": [["0", "x"], ["1", "y"], ["2", "z"]], + } + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="reversed dict views not supported before Python 3.11", +) +def test_df_init_from_generator_reversed_dict_view() -> None: + d = {0: "x", 1: "y", 2: "z"} + data = { + "rev_keys": reversed(d.keys()), + "rev_vals": reversed(d.values()), + "rev_itms": reversed(d.items()), + } + df = pl.DataFrame(data, schema_overrides={"rev_itms": pl.Object}) + + assert df.schema == { + "rev_keys": pl.Int64, + "rev_vals": pl.String, + "rev_itms": pl.Object, + } + assert df.to_dict(as_series=False) == { + "rev_keys": [2, 1, 0], + "rev_vals": ["z", "y", "x"], + "rev_itms": [(2, "z"), (1, "y"), (0, "x")], + } + + +def test_df_init_strict() -> None: + data = {"a": [1, 2, 3.0]} + schema = {"a": pl.Int8} + with pytest.raises(TypeError): + pl.DataFrame(data, schema=schema, strict=True) + + df = pl.DataFrame(data, schema=schema, strict=False) + + # TODO: This should result in a Float Series without nulls + # https://github.com/pola-rs/polars/issues/14427 + assert df["a"].to_list() == [1, 2, None] + + assert df["a"].dtype == pl.Int8 + + +def test_df_init_from_series_strict() -> None: + s = pl.Series("a", [-1, 0, 1]) + schema = {"a": pl.UInt8} + with pytest.raises(pl.ComputeError): + pl.DataFrame(s, schema=schema, strict=True) + + df = pl.DataFrame(s, schema=schema, strict=False) + + assert df["a"].to_list() == [None, 0, 1] + assert df["a"].dtype == pl.UInt8 diff --git a/py-polars/tests/unit/constructors/test_series.py b/py-polars/tests/unit/constructors/test_series.py new file mode 100644 index 0000000000000..b2285df441693 --- /dev/null +++ b/py-polars/tests/unit/constructors/test_series.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_series_mixed_dtypes_list() -> None: + values = [[0.1, 1]] + + with pytest.raises(pl.SchemaError, match="unexpected value"): + pl.Series(values) + + s = pl.Series(values, strict=False) + assert s.dtype == pl.List(pl.Float64) + assert s.to_list() == [[0.1, 1.0]] + + +def test_series_mixed_dtypes_string() -> None: + values = [[12], "foo", 9] + + with pytest.raises(pl.SchemaError, match="unexpected value"): + pl.Series(values) + + s = pl.Series(values, strict=False) + assert s.dtype == pl.String + assert s.to_list() == ["[12]", "foo", "9"] + assert s[1] == "foo" + + +def test_series_mixed_dtypes_object() -> None: + values = [[12], b"foo", 9] + + with pytest.raises(pl.SchemaError, match="unexpected value"): + pl.Series(values) + + s = pl.Series(values, strict=False) + assert s.dtype == pl.Object + assert s.to_list() == values + assert s[1] == b"foo" diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 06de35bdf2167..7f01fc0ff5710 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -16,6 +16,7 @@ import polars as pl import polars.selectors as cs +from polars._utils.construction import iterable_to_pydf from polars.datatypes import DTYPE_TEMPORAL_UNITS, INTEGER_DTYPES from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning from polars.testing import ( @@ -24,17 +25,13 @@ assert_series_equal, ) from polars.testing.parametric import columns -from polars.utils._construction import iterable_to_pydf if TYPE_CHECKING: - from polars.type_aliases import JoinStrategy, UniqueKeepStrategy - -if sys.version_info >= (3, 9): from zoneinfo import ZoneInfo + + from polars.type_aliases import JoinStrategy, UniqueKeepStrategy else: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo + from polars._utils.convert import string_to_zoneinfo as ZoneInfo def test_version() -> None: @@ -258,6 +255,15 @@ def test_from_arrow(monkeypatch: Any) -> None: assert df.schema == expected_schema assert df.rows() == expected_data + # record batches (inc. empty) + for b, n_expected in ( + (record_batches[0], 1), + (record_batches[0][:0], 0), + ): + df = cast(pl.DataFrame, pl.from_arrow(b)) + assert df.schema == expected_schema + assert df.rows() == expected_data[:n_expected] + empty_tbl = tbl[:0] # no rows df = cast(pl.DataFrame, pl.from_arrow(empty_tbl)) assert df.schema == expected_schema @@ -1049,32 +1055,6 @@ def test_literal_series() -> None: ) -def test_to_html() -> None: - # check it does not panic/error, and appears to contain - # a reasonable table with suitably escaped html entities. - df = pl.DataFrame( - { - "foo": [1, 2, 3], - "": ["a", "b", "c"], - "": ["a", "b", "c"], - } - ) - html = df._repr_html_() - for match in ( - "foo", - "<bar>", - "<baz", - "spam>", - "1", - "2", - "3", - ): - assert match in html, f"Expected to find {match!r} in html repr" - - def test_rename(df: pl.DataFrame) -> None: out = df.rename({"strings": "bars", "int": "foos"}) # check if we can select these new columns @@ -1188,34 +1168,6 @@ def __iter__(self) -> Iterator[Any]: pl.DataFrame(schema=["a", "b", "c", "d"]), ) - # dict-related generator-views - d = {0: "x", 1: "y", 2: "z"} - df = pl.DataFrame( - { - "keys": d.keys(), - "vals": d.values(), - "itms": d.items(), - } - ) - assert df.to_dict(as_series=False) == { - "keys": [0, 1, 2], - "vals": ["x", "y", "z"], - "itms": [(0, "x"), (1, "y"), (2, "z")], - } - if sys.version_info >= (3, 11): - df = pl.DataFrame( - { - "rev_keys": reversed(d.keys()), - "rev_vals": reversed(d.values()), - "rev_itms": reversed(d.items()), - } - ) - assert df.to_dict(as_series=False) == { - "rev_keys": [2, 1, 0], - "rev_vals": ["z", "y", "x"], - "rev_itms": [(2, "z"), (1, "y"), (0, "x")], - } - def test_from_rows() -> None: df = pl.from_records([[1, 2, "foo"], [2, 3, "bar"]]) @@ -1478,7 +1430,7 @@ def test_reproducible_hash_with_seeds() -> None: if platform.mac_ver()[-1] != "arm64": expected = pl.Series( "s", - [15801072432137883943, 6344663067812082469, 9604537446374444741], + [8661293245726181094, 9565952849861441858, 2921274555702885622], dtype=pl.UInt64, ) result = df.hash_rows(*seeds) @@ -1489,30 +1441,6 @@ def test_reproducible_hash_with_seeds() -> None: assert_series_equal(expected, result, check_names=False, check_exact=True) -def test_create_df_from_object() -> None: - class Foo: - def __init__(self, value: int) -> None: - self._value = value - - def __eq__(self, other: Any) -> bool: - return issubclass(other.__class__, self.__class__) and ( - self._value == other._value - ) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._value})" - - # from miscellaneous object - df = pl.DataFrame({"a": [Foo(1), Foo(2)]}) - assert df["a"].dtype == pl.Object - assert df.rows() == [(Foo(1),), (Foo(2),)] - - # from mixed-type input - df = pl.DataFrame({"x": [["abc", 12, 34.5]], "y": [1]}) - assert df.schema == {"x": pl.Object, "y": pl.Int64} - assert df.rows() == [(["abc", 12, 34.5], 1)] - - def test_hashing_on_python_objects() -> None: # see if we can do a group_by, drop_duplicates on a DataFrame with objects. # this requires that the hashing and aggregations are done on python objects @@ -2810,7 +2738,7 @@ def test_init_datetimes_with_timezone() -> None: tz_europe = "Europe/Amsterdam" dtm = datetime(2022, 10, 12, 12, 30) - for time_unit in DTYPE_TEMPORAL_UNITS | frozenset([None]): + for time_unit in DTYPE_TEMPORAL_UNITS: for type_overrides in ( { "schema": [ @@ -2912,7 +2840,7 @@ def test_init_physical_with_timezone() -> None: tz_asia = "Asia/Tokyo" dtm_us = 1665577800000000 - for time_unit in DTYPE_TEMPORAL_UNITS | frozenset([None]): + for time_unit in DTYPE_TEMPORAL_UNITS: dtm = {"ms": dtm_us // 1_000, "ns": dtm_us * 1_000}.get(str(time_unit), dtm_us) df = pl.DataFrame( data={"d1": [dtm], "d2": [dtm]}, diff --git a/py-polars/tests/unit/dataframe/test_repr_html.py b/py-polars/tests/unit/dataframe/test_repr_html.py new file mode 100644 index 0000000000000..8e7a62a6efc25 --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_repr_html.py @@ -0,0 +1,79 @@ +import polars as pl + + +def test_repr_html() -> None: + # check it does not panic/error, and appears to contain + # a reasonable table with suitably escaped html entities. + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "": ["a", "b", "c"], + "": ["a", "b", "c"], + } + ) + html = df._repr_html_() + for match in ( + "foo", + "<bar>", + "<baz", + "spam>", + "1", + "2", + "3", + ): + assert match in html, f"Expected to find {match!r} in html repr" + + +def test_html_tables() -> None: + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + + # default: header contains names/dtypes + header = "abci64i64i64" + assert header in df._repr_html_() + + # validate that relevant config options are respected + with pl.Config(tbl_hide_column_names=True): + header = "i64i64i64" + assert header in df._repr_html_() + + with pl.Config(tbl_hide_column_data_types=True): + header = "abc" + assert header in df._repr_html_() + + with pl.Config( + tbl_hide_column_data_types=True, + tbl_hide_column_names=True, + ): + header = "" + assert header in df._repr_html_() + + +def test_df_repr_html_max_rows_default() -> None: + df = pl.DataFrame({"a": range(50)}) + + html = df._repr_html_() + + expected_rows = 10 + assert html.count("") - 2 == expected_rows + + +def test_df_repr_html_max_rows_odd() -> None: + df = pl.DataFrame({"a": range(50)}) + + with pl.Config(tbl_rows=9): + html = df._repr_html_() + + expected_rows = 9 + assert html.count("") - 2 == expected_rows + + +def test_series_repr_html_max_rows_default() -> None: + s = pl.Series("a", range(50)) + + html = s._repr_html_() + + expected_rows = 10 + assert html.count("") - 2 == expected_rows diff --git a/py-polars/tests/unit/datatypes/test_binary.py b/py-polars/tests/unit/datatypes/test_binary.py index 2e3e198666c04..189786dfbc2ab 100644 --- a/py-polars/tests/unit/datatypes/test_binary.py +++ b/py-polars/tests/unit/datatypes/test_binary.py @@ -16,7 +16,7 @@ def test_binary_filter() -> None: def test_binary_to_list() -> None: - data = {"binary": [b"\xFD\x00\xFE\x00\xFF\x00", b"\x10\x00\x20\x00\x30\x00"]} + data = {"binary": [b"\xfd\x00\xfe\x00\xff\x00", b"\x10\x00\x20\x00\x30\x00"]} schema = {"binary": pl.Binary} print(pl.DataFrame(data, schema)) diff --git a/py-polars/tests/unit/datatypes/test_datatype.py b/py-polars/tests/unit/datatypes/test_datatype.py new file mode 100644 index 0000000000000..0fe164ded8c47 --- /dev/null +++ b/py-polars/tests/unit/datatypes/test_datatype.py @@ -0,0 +1,11 @@ +import copy + +import polars as pl + + +# https://github.com/pola-rs/polars/issues/14771 +def test_datatype_copy() -> None: + dtype = pl.Int64() + result = copy.deepcopy(dtype) + assert dtype == dtype + assert isinstance(result, pl.Int64) diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 1c125de7a2ebf..de08c311255f0 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -293,6 +293,11 @@ def test_decimal_aggregations() -> None: } ) + assert df.group_by("g").agg("a").sort("g").to_dict(as_series=False) == { + "g": [1, 2], + "a": [[D("0.1"), D("10.1")], [D("100.01"), D("9000.12")]], + } + assert df.group_by("g", maintain_order=True).agg( sum=pl.sum("a"), min=pl.min("a"), @@ -315,6 +320,21 @@ def test_decimal_aggregations() -> None: } +def test_decimal_df_vertical_sum() -> None: + df = pl.DataFrame({"a": [D("1.1"), D("2.2")]}) + expected = pl.DataFrame({"a": [D("3.3")]}) + + assert_frame_equal(df.sum(), expected) + + +def test_decimal_df_vertical_agg() -> None: + df = pl.DataFrame({"a": [D("1.0"), D("2.0"), D("3.0")]}) + expected_min = pl.DataFrame({"a": [D("1.0")]}) + expected_max = pl.DataFrame({"a": [D("3.0")]}) + assert_frame_equal(df.min(), expected_min) + assert_frame_equal(df.max(), expected_max) + + def test_decimal_in_filter() -> None: df = pl.DataFrame( { @@ -329,6 +349,46 @@ def test_decimal_in_filter() -> None: } +def test_decimal_sort() -> None: + df = pl.DataFrame( + { + "foo": [1, 2, 3], + "bar": [D("3.4"), D("2.1"), D("4.5")], + "baz": [1, 1, 2], + } + ) + assert df.sort("bar").to_dict(as_series=False) == { + "foo": [2, 1, 3], + "bar": [D("2.1"), D("3.4"), D("4.5")], + "baz": [1, 1, 2], + } + assert df.sort(["foo", "bar"]).to_dict(as_series=False) == { + "foo": [1, 2, 3], + "bar": [D("3.4"), D("2.1"), D("4.5")], + "baz": [1, 1, 2], + } + + assert df.select([pl.col("foo").sort_by("bar", descending=True).alias("s1")])[ + "s1" + ].to_list() == [3, 1, 2] + assert df.select([pl.col("foo").sort_by(["baz", "bar"]).alias("s2")])[ + "s2" + ].to_list() == [2, 1, 3] + + +def test_decimal_unique() -> None: + df = pl.DataFrame( + { + "foo": [1, 1, 2], + "bar": [D("3.4"), D("3.4"), D("4.5")], + } + ) + assert df.unique().sort("bar").to_dict(as_series=False) == { + "foo": [1, 2], + "bar": [D("3.4"), D("4.5")], + } + + def test_decimal_write_parquet_12375() -> None: f = io.BytesIO() df = pl.DataFrame( diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index 8afbee7d9b789..5562f703e2e9e 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +import re from datetime import date from textwrap import dedent from typing import Any, Callable @@ -9,7 +10,7 @@ import polars as pl from polars import StringCache -from polars.testing import assert_series_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_enum_creation() -> None: @@ -37,11 +38,21 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None: def test_enum_non_existent() -> None: with pytest.raises( pl.ComputeError, - match=("value 'c' is not present in Enum"), + match=re.escape( + "conversion from `str` to `enum` failed in column '' for 1 out of 4 values: [\"c\"]" + ), ): pl.Series([None, "a", "b", "c"], dtype=pl.Enum(categories=["a", "b"])) +def test_enum_non_existent_non_strict() -> None: + s = pl.Series( + [None, "a", "b", "c"], dtype=pl.Enum(categories=["a", "b"]), strict=False + ) + expected = pl.Series([None, "a", "b", None], dtype=pl.Enum(categories=["a", "b"])) + assert_series_equal(s, expected) + + def test_enum_from_schema_argument() -> None: df = pl.DataFrame( {"col1": ["a", "b", "c"]}, schema={"col1": pl.Enum(["a", "b", "c"])} @@ -106,6 +117,26 @@ def test_casting_to_an_enum_from_categorical() -> None: assert_series_equal(s2, expected) +def test_casting_to_an_enum_from_categorical_nonstrict() -> None: + dtype = pl.Enum(["a", "b"]) + s = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical) + s2 = s.cast(dtype, strict=False) + assert s2.dtype == dtype + assert s2.null_count() == 2 # "c" mapped to null + expected = pl.Series([None, "a", "b", None], dtype=dtype) + assert_series_equal(s2, expected) + + +def test_casting_to_an_enum_from_enum_nonstrict() -> None: + dtype = pl.Enum(["a", "b"]) + s = pl.Series([None, "a", "b", "c"], dtype=pl.Enum(["a", "b", "c"])) + s2 = s.cast(dtype, strict=False) + assert s2.dtype == dtype + assert s2.null_count() == 2 # "c" mapped to null + expected = pl.Series([None, "a", "b", None], dtype=dtype) + assert_series_equal(s2, expected) + + def test_casting_to_an_enum_from_integer() -> None: dtype = pl.Enum(["a", "b", "c"]) expected = pl.Series([None, "b", "a", "c"], dtype=dtype) @@ -128,7 +159,9 @@ def test_casting_to_an_enum_oob_from_integer() -> None: def test_casting_to_an_enum_from_categorical_nonexistent() -> None: with pytest.raises( pl.ComputeError, - match=("value 'c' is not present in Enum"), + match=( + r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]" + ), ): pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"])) @@ -148,7 +181,9 @@ def test_casting_to_an_enum_from_global_categorical() -> None: def test_casting_to_an_enum_from_global_categorical_nonexistent() -> None: with pytest.raises( pl.ComputeError, - match=("value 'c' is not present in Enum"), + match=( + r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]" + ), ): pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"])) @@ -305,7 +340,10 @@ def test_compare_enum_str_single_raise( s2 = "NOTEXIST" with pytest.raises( - pl.ComputeError, match="value 'NOTEXIST' is not present in Enum" + pl.ComputeError, + match=re.escape( + "conversion from `str` to `enum` failed in column '' for 1 out of 1 values: [\"NOTEXIST\"]" + ), ): op(s, s2) # type: ignore[arg-type] @@ -318,7 +356,8 @@ def test_compare_enum_str_raise() -> None: for s_compare in [s2, s_broadcast]: for op in [operator.le, operator.gt, operator.ge, operator.lt]: with pytest.raises( - pl.ComputeError, match="value 'd' is not present in Enum" + pl.ComputeError, + match="conversion from `str` to `enum` failed in column", ): op(s, s_compare) @@ -364,7 +403,7 @@ def test_enum_categories_unique() -> None: def test_enum_categories_series_input() -> None: - categories = pl.Series("a", ["x", "y", "z"]) + categories = pl.Series("a", ["a", "b", "c"]) dtype = pl.Enum(categories) assert_series_equal(dtype.categories, categories.alias("category")) @@ -402,3 +441,35 @@ def test_enum_cast_from_other_integer_dtype_oob() -> None: pl.ComputeError, match="conversion from `u64` to `u32` failed in column" ): series.cast(enum_dtype) + + +def test_enum_creating_col_expr() -> None: + df = pl.DataFrame( + { + "col1": ["a", "b", "c"], + "col2": ["d", "e", "f"], + "col3": ["g", "h", "i"], + }, + schema={ + "col1": pl.Enum(["a", "b", "c"]), + "col2": pl.Categorical(), + "col3": pl.Enum(["g", "h", "i"]), + }, + ) + + out = df.select(pl.col(pl.Enum)) + expected = df.select("col1", "col3") + assert_frame_equal(out, expected) + + +def test_enum_cse_eq() -> None: + df = pl.DataFrame({"a": [1]}) + + # these both share the value "a", which is used in both expressions + dt1 = pl.Enum(["a", "b"]) + dt2 = pl.Enum(["a", "c"]) + + df.lazy().select( + pl.when(True).then(pl.lit("a", dtype=dt1)).alias("dt1"), + pl.when(True).then(pl.lit("a", dtype=dt2)).alias("dt2"), + ).collect() diff --git a/py-polars/tests/unit/datatypes/test_float.py b/py-polars/tests/unit/datatypes/test_float.py index 62c975b262987..5a318fd0015cc 100644 --- a/py-polars/tests/unit/datatypes/test_float.py +++ b/py-polars/tests/unit/datatypes/test_float.py @@ -1,4 +1,7 @@ +import pytest + import polars as pl +from polars.testing import assert_series_equal def test_nan_in_group_by_agg() -> None: @@ -32,3 +35,264 @@ def test_nan_aggregations() -> None: str(df.group_by("b").agg(aggs).to_dict(as_series=False)) == "{'b': [1], 'max': [3.0], 'min': [1.0], 'nan_max': [nan], 'nan_min': [nan]}" ) + + +@pytest.mark.parametrize("descending", [True, False]) +def test_sorted_nan_max_12931(descending: bool) -> None: + s = pl.Series("x", [1.0, 2.0, float("nan")]).sort(descending=descending) + + assert s.max() == 2.0 + assert s.arg_max() == 1 + + # Test full-nan + s = pl.Series("x", [float("nan"), float("nan"), float("nan")]).sort( + descending=descending + ) + + out = s.max() + assert out != out + + # This is flipped because float arg_max calculates the index as + # * sorted ascending: (index of left-most NaN) - 1, saturating subtraction at 0 + # * sorted descending: (index of right-most NaN) + 1, saturating addition at s.len() + assert s.arg_max() == (0, 2)[descending] + + s = pl.Series("x", [1.0, 2.0, 3.0]).sort(descending=descending) + + assert s.max() == 3.0 + assert s.arg_max() == (2, 0)[descending] + + +@pytest.mark.parametrize( + ("s", "expect"), + [ + ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ), + pl.Series("x", [None, 0.0, 1.0, float("nan")]), + ), + ( + # No nulls + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + ], + ), + pl.Series("x", [0.0, 1.0, float("nan")]), + ), + ], +) +def test_unique(s: pl.Series, expect: pl.Series) -> None: + out = s.unique() + assert_series_equal(expect, out) + + out = s.n_unique() # type: ignore[assignment] + assert expect.len() == out + + out = s.gather(s.arg_unique()).sort() + assert_series_equal(expect, out) + + +def test_unique_counts() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + expect = pl.Series("x", [2, 2, 1, 1], dtype=pl.UInt32) + out = s.unique_counts() + assert_series_equal(expect, out) + + +def test_hash() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ).hash() + + # check them against each other since hash is not stable + assert s.item(0) == s.item(1) # hash(-0.0) == hash(0.0) + assert s.item(2) == s.item(3) # hash(float('-nan')) == hash(float('nan')) + + +def test_group_by() -> None: + # Test num_groups_proxy + # * -0.0 and 0.0 in same groups + # * -nan and nan in same groups + df = ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + .to_frame() + .with_row_index() + .with_columns(a=pl.lit("a")) + ) + + expect = pl.Series("index", [[0, 1], [2, 3], [4], [5]], dtype=pl.List(pl.UInt32)) + expect_no_null = expect.head(3) + + for group_keys in (("x",), ("x", "a")): + for maintain_order in (True, False): + for drop_nulls in (True, False): + out = df + if drop_nulls: + out = out.drop_nulls() + + out = ( + out.group_by(group_keys, maintain_order=maintain_order) # type: ignore[assignment] + .agg("index") + .sort(pl.col("index").list.get(0)) + .select("index") + .to_series() + ) + + if drop_nulls: + assert_series_equal(expect_no_null, out) # type: ignore[arg-type] + else: + assert_series_equal(expect, out) # type: ignore[arg-type] + + +def test_joins() -> None: + # Test that -0.0 joins with 0.0 and nan joins with nan + df = ( + pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + .to_frame() + .with_row_index() + .with_columns(a=pl.lit("a")) + ) + + rhs = ( + pl.Series("x", [0.0, float("nan"), 3.0]) + .to_frame() + .with_columns(a=pl.lit("a"), rhs=True) + ) + + for join_on in ( + # Single and multiple keys + ("x",), + ( + "x", + "a", + ), + ): + how = "left" + expect = pl.Series("rhs", [True, True, True, True, None, None]) + out = df.join(rhs, on=join_on, how=how).sort("index").select("rhs").to_series() # type: ignore[arg-type] + assert_series_equal(expect, out) + + how = "inner" + expect = pl.Series("index", [0, 1, 2, 3], dtype=pl.UInt32) + out = ( + df.join(rhs, on=join_on, how=how).sort("index").select("index").to_series() # type: ignore[arg-type] + ) + assert_series_equal(expect, out) + + how = "outer" + expect = pl.Series("rhs", [True, True, True, True, None, None, True]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("rhs") + .to_series() + ) + assert_series_equal(expect, out) + + how = "semi" + expect = pl.Series("x", [-0.0, 0.0, float("-nan"), float("nan")]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("x") + .to_series() + ) + assert_series_equal(expect, out) + + how = "anti" + expect = pl.Series("x", [1.0, None]) + out = ( + df.join(rhs, on=join_on, how=how) # type: ignore[arg-type] + .sort("index", nulls_last=True) + .select("x") + .to_series() + ) + assert_series_equal(expect, out) + + # test asof + # note that nans never join because nans are always greater than the other + # side of the comparison (i.e. NaN > tolerance) + expect = pl.Series("rhs", [True, True, None, None, None, None]) + out = ( + df.sort("x") + .join_asof(rhs.sort("x"), on="x", tolerance=0) + .sort("index") + .select("rhs") + .to_series() + ) + assert_series_equal(expect, out) + + +def test_first_last_distinct() -> None: + s = pl.Series( + "x", + [ + -0.0, + 0.0, + float("-nan"), + float("nan"), + 1.0, + None, + ], + ) + + assert_series_equal( + pl.Series("x", [True, False, True, False, True, True]), s.is_first_distinct() + ) + + assert_series_equal( + pl.Series("x", [False, True, False, True, True, True]), s.is_last_distinct() + ) diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index f439781b44221..ba580ea8e6ba0 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -23,15 +23,20 @@ def test_dtype() -> None: assert a.dtype.is_(pl.List(pl.Int64)) # explicit + u64_max = (2**64) - 1 df = pl.DataFrame( data={ "i": [[1, 2, 3]], + "li": [[[1, 2, 3]]], + "u": [[u64_max]], "tm": [[time(10, 30, 45)]], "dt": [[date(2022, 12, 31)]], "dtm": [[datetime(2022, 12, 31, 1, 2, 3)]], }, schema=[ ("i", pl.List(pl.Int8)), + ("li", pl.List(pl.List(pl.Int8))), + ("u", pl.List(pl.UInt64)), ("tm", pl.List(pl.Time)), ("dt", pl.List(pl.Date)), ("dtm", pl.List(pl.Datetime)), @@ -39,6 +44,8 @@ def test_dtype() -> None: ) assert df.schema == { "i": pl.List(pl.Int8), + "li": pl.List(pl.List(pl.Int8)), + "u": pl.List(pl.UInt64), "tm": pl.List(pl.Time), "dt": pl.List(pl.Date), "dtm": pl.List(pl.Datetime), @@ -48,6 +55,8 @@ def test_dtype() -> None: assert df.rows() == [ ( [1, 2, 3], + [[1, 2, 3]], + [u64_max], [time(10, 30, 45)], [date(2022, 12, 31)], [datetime(2022, 12, 31, 1, 2, 3)], @@ -752,3 +761,32 @@ def test_list_median(data_dispersion: pl.DataFrame) -> None: ) assert_frame_equal(result, expected) + + +def test_list_gather_null_struct_14927() -> None: + df = pl.DataFrame( + [ + { + "index": 0, + "col_0": [{"field_0": 1.0}], + }, + { + "index": 1, + "col_0": None, + }, + ] + ) + + expected = pl.DataFrame( + {"index": [1], "col_0": [None], "field_0": [None]}, + schema={**df.schema, "field_0": pl.Float64}, + ) + expr = pl.col("col_0").list.get(0).struct.field("field_0") + out = df.filter(pl.col("index") > 0).with_columns(expr) + assert_frame_equal(out, expected) + + +def test_list_of_series_with_nulls() -> None: + inner_series = pl.Series("inner", [1, 2, 3]) + s = pl.Series("a", [inner_series, None]) + assert_series_equal(s, pl.Series("a", [[1, 2, 3], None])) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a557cd0397866..4f6b2b18f32be 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -12,7 +12,12 @@ import polars as pl from polars.datatypes import DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, TEMPORAL_DTYPES -from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning +from polars.exceptions import ( + ComputeError, + InvalidOperationError, + PolarsInefficientMapWarning, + TimeZoneAwareConstructorWarning, +) from polars.testing import ( assert_frame_equal, assert_series_equal, @@ -24,7 +29,7 @@ from polars.type_aliases import Ambiguous, PolarsTemporalType, TimeUnit else: - from polars.utils.convert import get_zoneinfo as ZoneInfo + from polars._utils.convert import string_to_zoneinfo as ZoneInfo def test_fill_null() -> None: @@ -947,45 +952,50 @@ def test_temporal_dtypes_map_elements( ) const_dtm = datetime(2010, 9, 12) - assert_frame_equal( - df.with_columns( - [ - # don't actually do any of this; native expressions are MUCH faster ;) - pl.col("timestamp") - .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) - .alias("const_dtm"), - pl.col("timestamp") - .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) - .alias("date"), - pl.col("timestamp") - .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) - .alias("time"), - ] - ), - pl.DataFrame( - [ - ( - datetime(2010, 9, 12, 10, 19, 54), - datetime(2010, 9, 12, 0, 0), - date(2010, 9, 12), - time(10, 19, 54), - ), - (None, expected_value, None, None), - ( - datetime(2009, 2, 13, 23, 31, 30), - datetime(2010, 9, 12, 0, 0), - date(2009, 2, 13), - time(23, 31, 30), - ), - ], - schema={ - "timestamp": pl.Datetime("ms"), - "const_dtm": pl.Datetime("us"), - "date": pl.Date, - "time": pl.Time, - }, - ), - ) + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Replace this expression.*lambda x:", + ): + assert_frame_equal( + df.with_columns( + [ + # don't actually do this; native expressions are MUCH faster ;) + pl.col("timestamp") + .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) + .alias("const_dtm"), + # note: the below now trigger a PolarsInefficientMapWarning + pl.col("timestamp") + .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) + .alias("date"), + pl.col("timestamp") + .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) + .alias("time"), + ] + ), + pl.DataFrame( + [ + ( + datetime(2010, 9, 12, 10, 19, 54), + datetime(2010, 9, 12, 0, 0), + date(2010, 9, 12), + time(10, 19, 54), + ), + (None, expected_value, None, None), + ( + datetime(2009, 2, 13, 23, 31, 30), + datetime(2010, 9, 12, 0, 0), + date(2009, 2, 13), + time(23, 31, 30), + ), + ], + schema={ + "timestamp": pl.Datetime("ms"), + "const_dtm": pl.Datetime("us"), + "date": pl.Date, + "time": pl.Time, + }, + ), + ) def test_timelike_init() -> None: @@ -1397,7 +1407,7 @@ def test_replace_time_zone() -> None: @pytest.mark.parametrize( ("to_tz", "tzinfo"), [ - ("America/Barbados", ZoneInfo(key="America/Barbados")), + ("America/Barbados", ZoneInfo("America/Barbados")), (None, None), ], ) @@ -1421,7 +1431,7 @@ def test_strptime_with_tz() -> None: .str.strptime(pl.Datetime("us", "Africa/Monrovia")) .item() ) - assert result == datetime(2020, 1, 1, 3, tzinfo=ZoneInfo(key="Africa/Monrovia")) + assert result == datetime(2020, 1, 1, 3, tzinfo=ZoneInfo("Africa/Monrovia")) @pytest.mark.parametrize( @@ -1487,7 +1497,7 @@ def test_convert_time_zone_lazy_schema() -> None: def test_convert_time_zone_on_tz_naive() -> None: ts = pl.Series(["2020-01-01"]).str.strptime(pl.Datetime) result = ts.dt.convert_time_zone("Asia/Kathmandu").item() - expected = datetime(2020, 1, 1, 5, 45, tzinfo=ZoneInfo(key="Asia/Kathmandu")) + expected = datetime(2020, 1, 1, 5, 45, tzinfo=ZoneInfo("Asia/Kathmandu")) assert result == expected result = ( ts.dt.replace_time_zone("UTC").dt.convert_time_zone("Asia/Kathmandu").item() @@ -1573,8 +1583,8 @@ def test_replace_time_zone_from_naive() -> None: pl.col("date").cast(pl.Datetime).dt.replace_time_zone("America/New_York") ).to_dict(as_series=False) == { "date": [ - datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 1, 2, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 1, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), ] } @@ -1661,6 +1671,48 @@ def test_replace_time_zone_sortedness_expressions( assert result["ts"].flags["SORTED_ASC"] == expected_sortedness +def test_invalid_ambiguous_value_in_expression() -> None: + df = pl.DataFrame( + {"a": [datetime(2020, 10, 25, 1)] * 2, "b": ["earliest", "cabbage"]} + ) + with pytest.raises(InvalidOperationError, match="Invalid argument cabbage"): + df.select( + pl.col("a").dt.replace_time_zone("Europe/London", ambiguous=pl.col("b")) + ) + + +def test_replace_time_zone_ambiguous_null() -> None: + df = pl.DataFrame( + { + "a": [datetime(2020, 10, 25, 1)] * 3 + [None], + "b": ["earliest", "latest", "null", "raise"], + } + ) + # expression containing 'null' + result = df.select( + pl.col("a").dt.replace_time_zone("Europe/London", ambiguous=pl.col("b")) + )["a"] + expected = [ + datetime(2020, 10, 25, 1, fold=0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, fold=1, tzinfo=ZoneInfo("Europe/London")), + None, + None, + ] + assert result[0] == expected[0] + assert result[1] == expected[1] + assert result[2] == expected[2] + assert result[3] == expected[3] + + # single 'null' value + result = df.select( + pl.col("a").dt.replace_time_zone("Europe/London", ambiguous="null") + )["a"] + assert result[0] is None + assert result[1] is None + assert result[2] is None + assert result[3] is None + + def test_use_earliest_deprecation() -> None: # strptime with pytest.warns( @@ -1822,6 +1874,19 @@ def test_ambiguous_expressions() -> None: assert_series_equal(result, expected) +def test_single_ambiguous_null() -> None: + df = pl.DataFrame( + {"ts": [datetime(2020, 10, 2, 1, 1)], "ambiguous": [None]}, + schema_overrides={"ambiguous": pl.String}, + ) + result = df.select( + pl.col("ts").dt.replace_time_zone( + "Europe/London", ambiguous=pl.col("ambiguous") + ) + )["ts"].item() + assert result is None + + def test_unlocalize() -> None: tz_naive = pl.Series(["2020-01-01 03:00:00"]).str.strptime(pl.Datetime) tz_aware = tz_naive.dt.replace_time_zone("UTC").dt.convert_time_zone( @@ -1845,22 +1910,22 @@ def test_tz_aware_truncate() -> None: result = df.with_columns(pl.col("dt").dt.truncate("1d").alias("trunced")) expected = { "dt": [ - datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 1, 12, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 2, 12, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 3, 12, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 1, 12, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 12, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 12, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo("America/New_York")), ], "trunced": [ - datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo("America/New_York")), ], } assert result.to_dict(as_series=False) == expected @@ -1891,34 +1956,34 @@ def test_tz_aware_truncate() -> None: datetime(2022, 1, 1, 6, 0), ], "UTC": [ - datetime(2021, 12, 31, 23, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 1, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 2, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 3, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 4, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 5, 0, tzinfo=ZoneInfo(key="UTC")), - datetime(2022, 1, 1, 6, 0, tzinfo=ZoneInfo(key="UTC")), + datetime(2021, 12, 31, 23, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 1, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 2, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 3, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 4, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 5, 0, tzinfo=ZoneInfo("UTC")), + datetime(2022, 1, 1, 6, 0, tzinfo=ZoneInfo("UTC")), ], "CST": [ - datetime(2021, 12, 31, 17, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 18, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 19, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 20, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 21, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 22, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 23, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo(key="US/Central")), + datetime(2021, 12, 31, 17, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 18, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 19, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 20, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 21, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 22, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 23, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("US/Central")), ], "CST truncated": [ - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo(key="US/Central")), - datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo(key="US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2021, 12, 31, 0, 0, tzinfo=ZoneInfo("US/Central")), + datetime(2022, 1, 1, 0, 0, tzinfo=ZoneInfo("US/Central")), ], } @@ -1947,10 +2012,10 @@ def test_tz_aware_to_string() -> None: result = df.with_columns(pl.col("dt").dt.to_string("%c").alias("fmt")) expected = { "dt": [ - datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), + datetime(2022, 11, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 2, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 3, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(2022, 11, 4, 0, 0, tzinfo=ZoneInfo("America/New_York")), ], "fmt": [ "Tue Nov 1 00:00:00 2022", @@ -2008,12 +2073,12 @@ def test_tz_aware_filter_lit() -> None: datetime(1970, 1, 1, 5, 0), ], "nyc": [ - datetime(1970, 1, 1, 0, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(1970, 1, 1, 1, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(1970, 1, 1, 2, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(1970, 1, 1, 3, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(1970, 1, 1, 4, 0, tzinfo=ZoneInfo(key="America/New_York")), - datetime(1970, 1, 1, 5, 0, tzinfo=ZoneInfo(key="America/New_York")), + datetime(1970, 1, 1, 0, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 1, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 2, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 3, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 4, 0, tzinfo=ZoneInfo("America/New_York")), + datetime(1970, 1, 1, 5, 0, tzinfo=ZoneInfo("America/New_York")), ], } @@ -2088,26 +2153,26 @@ def test_truncate_expr() -> None: ambiguous_expr = df.select(pl.col("date").dt.truncate(every=pl.lit("30m"))) assert ambiguous_expr.to_dict(as_series=False) == { "date": [ - datetime(2020, 10, 25, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo("Europe/London")), ] } all_expr = df.select(pl.col("date").dt.truncate(every=pl.col("every"))) assert all_expr.to_dict(as_series=False) == { "date": [ - datetime(2020, 10, 25, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 0, 45, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 0, 45, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo("Europe/London")), ] } @@ -2287,6 +2352,13 @@ def test_truncate_ambiguous() -> None: assert_series_equal(result, expected) +def test_truncate_non_existent_14957() -> None: + with pytest.raises(ComputeError, match="non-existent"): + pl.Series([datetime(2020, 3, 29, 2, 1)]).dt.replace_time_zone( + "Europe/London" + ).dt.truncate("46m") + + def test_round_ambiguous() -> None: t = ( pl.datetime_range( @@ -2343,13 +2415,13 @@ def test_round_ambiguous() -> None: df = df.select(pl.col("date").dt.round("30m", ambiguous=pl.col("ambiguous"))) assert df.to_dict(as_series=False) == { "date": [ - datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 2, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 25, 2, 30, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 25, 2, 30, tzinfo=ZoneInfo("Europe/London")), ] } @@ -2652,3 +2724,9 @@ def test_rolling_duplicates() -> None: assert df.sort("ts").with_columns(pl.col("value").rolling_max("1d", by="ts"))[ "value" ].to_list() == [1, 1] + + +def test_datetime_time_unit_none_deprecated() -> None: + with pytest.deprecated_call(): + dtype = pl.Datetime(time_unit=None) # type: ignore[arg-type] + assert dtype.time_unit == "us" diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index 10b5b34cb8a35..08ba418434951 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -1,16 +1,8 @@ from __future__ import annotations -import sys from datetime import date, datetime, timedelta, timezone from itertools import permutations -from typing import Any, cast - -if sys.version_info >= (3, 9): - from zoneinfo import ZoneInfo -else: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo +from typing import TYPE_CHECKING, Any, cast import numpy as np import pytest @@ -26,6 +18,11 @@ ) from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from zoneinfo import ZoneInfo +else: + from polars._utils.convert import string_to_zoneinfo as ZoneInfo + def test_arg_true() -> None: df = pl.DataFrame({"a": [1, 1, 2, 1]}) @@ -430,34 +427,6 @@ def test_logical_boolean() -> None: df.select([(pl.col("a") > pl.col("b")) or (pl.col("b") > pl.col("b"))]) -# https://github.com/pola-rs/polars/issues/4951 -def test_ewm_with_multiple_chunks() -> None: - df0 = pl.DataFrame( - data=[ - ("w", 6.0, 1.0), - ("x", 5.0, 2.0), - ("y", 4.0, 3.0), - ("z", 3.0, 4.0), - ], - schema=["a", "b", "c"], - ).with_columns( - [ - pl.col(pl.Float64).log().diff().name.prefix("ld_"), - ] - ) - assert df0.n_chunks() == 1 - - # NOTE: We aren't testing whether `select` creates two chunks; - # we just need two chunks to properly test `ewm_mean` - df1 = df0.select(["ld_b", "ld_c"]) - assert df1.n_chunks() == 2 - - ewm_std = df1.with_columns( - pl.all().ewm_std(com=20).name.prefix("ewm_"), - ) - assert ewm_std.null_count().sum_horizontal()[0] == 4 - - def test_lit_dtypes() -> None: def lit_series(value: Any, dtype: pl.PolarsDataType | None) -> pl.Series: return pl.select(pl.lit(value, dtype=dtype)).to_series() @@ -477,7 +446,7 @@ def lit_series(value: Any, dtype: pl.PolarsDataType | None) -> pl.Series: "dtm_aware_0": lit_series(d, pl.Datetime("us", "Asia/Kathmandu")), "dtm_aware_1": lit_series(d_tz, pl.Datetime("us")), "dtm_aware_2": lit_series(d_tz, None), - "dtm_aware_3": lit_series(d, pl.Datetime(None, "Asia/Kathmandu")), + "dtm_aware_3": lit_series(d, pl.Datetime(time_zone="Asia/Kathmandu")), "dur_ms": lit_series(td, pl.Duration("ms")), "dur_us": lit_series(td, pl.Duration("us")), "dur_ns": lit_series(td, pl.Duration("ns")), diff --git a/py-polars/tests/unit/functions/as_datatype/test_as_datatype.py b/py-polars/tests/unit/functions/as_datatype/test_as_datatype.py index c1e266933f841..2074865f694c8 100644 --- a/py-polars/tests/unit/functions/as_datatype/test_as_datatype.py +++ b/py-polars/tests/unit/functions/as_datatype/test_as_datatype.py @@ -13,7 +13,7 @@ from polars.type_aliases import TimeUnit else: - from polars.utils.convert import get_zoneinfo as ZoneInfo + from polars._utils.convert import string_to_zoneinfo as ZoneInfo def test_date_datetime() -> None: diff --git a/py-polars/tests/unit/functions/range/test_datetime_range.py b/py-polars/tests/unit/functions/range/test_datetime_range.py index 70d1cdda0c047..b9873f833b4bf 100644 --- a/py-polars/tests/unit/functions/range/test_datetime_range.py +++ b/py-polars/tests/unit/functions/range/test_datetime_range.py @@ -16,7 +16,7 @@ from polars.datatypes import PolarsDataType from polars.type_aliases import ClosedInterval, TimeUnit else: - from polars.utils.convert import get_zoneinfo as ZoneInfo + from polars._utils.convert import string_to_zoneinfo as ZoneInfo def test_datetime_range() -> None: @@ -166,13 +166,13 @@ def test_timezone_aware_datetime_range() -> None: assert pl.datetime_range( low, high, interval=timedelta(days=5), eager=True ).to_list() == [ - datetime(2022, 10, 17, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), - datetime(2022, 10, 22, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), - datetime(2022, 10, 27, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), - datetime(2022, 11, 1, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), - datetime(2022, 11, 6, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), - datetime(2022, 11, 11, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), - datetime(2022, 11, 16, 10, 0, tzinfo=ZoneInfo(key="Asia/Shanghai")), + datetime(2022, 10, 17, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 10, 22, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 10, 27, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 1, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 6, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 11, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + datetime(2022, 11, 16, 10, 0, tzinfo=ZoneInfo("Asia/Shanghai")), ] with pytest.raises( diff --git a/py-polars/tests/unit/functions/test_cum_count.py b/py-polars/tests/unit/functions/test_cum_count.py index bbedad60d5986..b68db3678ab9e 100644 --- a/py-polars/tests/unit/functions/test_cum_count.py +++ b/py-polars/tests/unit/functions/test_cum_count.py @@ -15,11 +15,13 @@ def test_cum_count_no_args(reverse: bool, output: list[int]) -> None: assert_frame_equal(result, expected) -def test_cum_count_single_arg() -> None: +@pytest.mark.parametrize(("reverse", "output"), [(False, [1, 2, 2]), (True, [2, 1, 0])]) +def test_cum_count_single_arg(reverse: bool, output: list[int]) -> None: df = pl.DataFrame({"a": [5, 5, None]}) - result = df.select(pl.cum_count("a")) - expected = pl.Series("a", [1, 2, 2], dtype=pl.UInt32).to_frame() + result = df.select(pl.cum_count("a", reverse=reverse)) + expected = pl.Series("a", output, dtype=pl.UInt32).to_frame() assert_frame_equal(result, expected) + assert result.to_series().flags[("SORTED_ASC", "SORTED_DESC")[reverse]] def test_cum_count_multi_arg() -> None: diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py index 10c0602b47c34..b5aec7177072c 100644 --- a/py-polars/tests/unit/functions/test_when_then.py +++ b/py-polars/tests/unit/functions/test_when_then.py @@ -242,6 +242,18 @@ def test_comp_categorical_lit_dtype() -> None: ).dtypes == [pl.Categorical, pl.Int32] +def test_comp_incompatible_enum_dtype() -> None: + df = pl.DataFrame({"a": pl.Series(["a", "b"], dtype=pl.Enum(["a", "b"]))}) + + with pytest.raises( + pl.ComputeError, + match="conversion from `str` to `enum` failed in column 'literal'", + ): + df.with_columns( + pl.when(pl.col("a") == "a").then(pl.col("a")).otherwise(pl.lit("c")) + ) + + def test_predicate_broadcast() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py index 2b41519777e6c..68b1f7696dd7a 100644 --- a/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_df.py @@ -121,19 +121,19 @@ def test_df_to_numpy_decimal(use_pyarrow: bool) -> None: assert_array_equal(result, expected) -def test_to_numpy_zero_copy_path() -> None: +def test_df_to_numpy_zero_copy_path() -> None: rows = 10 cols = 5 x = np.ones((rows, cols), order="F") x[:, 1] = 2.0 df = pl.DataFrame(x) - x = df.to_numpy() + x = df.to_numpy(allow_copy=False) assert x.flags["F_CONTIGUOUS"] assert not x.flags["WRITEABLE"] assert str(x[0, :]) == "[1. 2. 1. 1. 1.]" -def test_to_numpy_zero_copy_path_writeable() -> None: +def test_to_numpy_zero_copy_path_writable() -> None: rows = 10 cols = 5 x = np.ones((rows, cols), order="F") @@ -141,3 +141,23 @@ def test_to_numpy_zero_copy_path_writeable() -> None: df = pl.DataFrame(x) x = df.to_numpy(writable=True) assert x.flags["WRITEABLE"] + + +def test_df_to_numpy_structured_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2]}) + msg = "cannot create structured array without copying data" + with pytest.raises(RuntimeError, match=msg): + df.to_numpy(structured=True, allow_copy=False) + + +def test_df_to_numpy_writable_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2]}) + msg = "cannot create writable array without copying data" + with pytest.raises(RuntimeError, match=msg): + df.to_numpy(allow_copy=False, writable=True) + + +def test_df_to_numpy_not_zero_copy() -> None: + df = pl.DataFrame({"a": [1, 2, None]}) + with pytest.raises(RuntimeError): + df.to_numpy(allow_copy=False) diff --git a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py index 0f980e4fe2dd2..fe2909672fa8f 100644 --- a/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py +++ b/py-polars/tests/unit/interop/numpy/test_to_numpy_series.py @@ -25,9 +25,9 @@ def assert_zero_copy(s: pl.Series, arr: np.ndarray[Any, Any]) -> None: assert s_ptr == arr_ptr -def assert_zero_copy_only_raises(s: pl.Series) -> None: +def assert_allow_copy_false_raises(s: pl.Series) -> None: with pytest.raises(ValueError, match="cannot return a zero-copy array"): - s.to_numpy(use_pyarrow=False, zero_copy_only=True) + s.to_numpy(use_pyarrow=False, allow_copy=False) @pytest.mark.parametrize( @@ -48,8 +48,8 @@ def assert_zero_copy_only_raises(s: pl.Series) -> None: def test_series_to_numpy_numeric_zero_copy( dtype: pl.PolarsDataType, expected_dtype: npt.DTypeLike ) -> None: - s = pl.Series([1, 2, 3]).cast(dtype) # =dtype, strict=False) - result = s.to_numpy(use_pyarrow=False, zero_copy_only=True) + s = pl.Series([1, 2, 3]).cast(dtype) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) assert_zero_copy(s, result) assert result.tolist() == s.to_list() @@ -80,7 +80,7 @@ def test_series_to_numpy_numeric_with_nulls( assert result.tolist()[:-1] == s.to_list()[:-1] assert np.isnan(result[-1]) assert result.dtype == expected_dtype - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) @pytest.mark.parametrize( @@ -101,7 +101,7 @@ def test_series_to_numpy_temporal_zero_copy( ) -> None: values = [0, 2_000, 1_000_000] s = pl.Series(values, dtype=dtype, strict=False) - result = s.to_numpy(use_pyarrow=False, zero_copy_only=True) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) assert_zero_copy(s, result) # NumPy tolist returns integers for ns precision @@ -115,7 +115,7 @@ def test_series_to_numpy_temporal_zero_copy( def test_series_to_numpy_datetime_with_tz_zero_copy() -> None: values = [datetime(1970, 1, 1), datetime(2024, 2, 28)] s = pl.Series(values).dt.convert_time_zone("Europe/Amsterdam") - result = s.to_numpy(use_pyarrow=False, zero_copy_only=True) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) assert_zero_copy(s, result) assert result.tolist() == values @@ -130,7 +130,7 @@ def test_series_to_numpy_date() -> None: assert s.to_list() == result.tolist() assert result.dtype == np.dtype("datetime64[D]") - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) @pytest.mark.parametrize( @@ -159,7 +159,7 @@ def test_series_to_numpy_temporal_with_nulls( else: assert result.tolist() == s.to_list() assert result.dtype == expected_dtype - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) def test_series_to_numpy_datetime_with_tz_with_nulls() -> None: @@ -169,7 +169,7 @@ def test_series_to_numpy_datetime_with_tz_with_nulls() -> None: assert result.tolist() == values assert result.dtype == np.dtype("datetime64[us]") - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) @pytest.mark.parametrize( @@ -199,7 +199,7 @@ def test_to_numpy_object_dtypes( assert result.tolist() == values assert result.dtype == np.object_ - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) def test_series_to_numpy_bool() -> None: @@ -208,7 +208,7 @@ def test_series_to_numpy_bool() -> None: assert s.to_list() == result.tolist() assert result.dtype == np.bool_ - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) def test_series_to_numpy_bool_with_nulls() -> None: @@ -217,7 +217,7 @@ def test_series_to_numpy_bool_with_nulls() -> None: assert s.to_list() == result.tolist() assert result.dtype == np.object_ - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) def test_series_to_numpy_array_of_int() -> None: @@ -249,7 +249,7 @@ def test_series_to_numpy_array_with_nulls() -> None: expected = np.array([[1.0, 2.0], [3.0, 4.0], [np.nan, np.nan]]) assert_array_equal(result, expected) assert result.dtype == np.float64 - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) def test_to_numpy_null() -> None: @@ -258,12 +258,12 @@ def test_to_numpy_null() -> None: expected = np.array([np.nan, np.nan], dtype=np.float32) assert_array_equal(result, expected) assert result.dtype == np.float32 - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) def test_to_numpy_empty() -> None: s = pl.Series(dtype=pl.String) - result = s.to_numpy(use_pyarrow=False, zero_copy_only=True) + result = s.to_numpy(use_pyarrow=False, allow_copy=False) assert result.dtype == np.object_ assert result.shape == (0,) assert result.size == 0 @@ -278,7 +278,15 @@ def test_to_numpy_chunked() -> None: assert result.tolist() == s.to_list() assert result.dtype == np.int64 - assert_zero_copy_only_raises(s) + assert_allow_copy_false_raises(s) + + +def test_zero_copy_only_deprecated() -> None: + values = [1, 2] + s = pl.Series([1, 2]) + with pytest.deprecated_call(): + result = s.to_numpy(zero_copy_only=True) + assert result.tolist() == values def test_series_to_numpy_temporal() -> None: @@ -372,7 +380,8 @@ def test_to_numpy2( def test_view() -> None: s = pl.Series("a", [1.0, 2.5, 3.0]) - result = s._view() + with pytest.deprecated_call(): + result = s.view() assert isinstance(result, np.ndarray) assert np.all(result == np.array([1.0, 2.5, 3.0])) @@ -380,27 +389,22 @@ def test_view() -> None: def test_view_nulls() -> None: s = pl.Series("b", [1, 2, None]) assert s.has_validity() - with pytest.raises(AssertionError): - s._view() + with pytest.deprecated_call(), pytest.raises(AssertionError): + s.view() def test_view_nulls_sliced() -> None: s = pl.Series("b", [1, 2, None]) sliced = s[:2] - assert np.all(sliced._view() == np.array([1, 2])) + with pytest.deprecated_call(): + view = sliced.view() + assert np.all(view == np.array([1, 2])) assert not sliced.has_validity() def test_view_ub() -> None: # this would be UB if the series was dropped and not passed to the view s = pl.Series([3, 1, 5]) - result = s.sort()._view() - assert np.sum(result) == 9 - - -def test_view_deprecated() -> None: - s = pl.Series("a", [1.0, 2.5, 3.0]) with pytest.deprecated_call(): - result = s.view() - assert isinstance(result, np.ndarray) - assert np.all(result == np.array([1.0, 2.5, 3.0])) + result = s.sort().view() + assert np.sum(result) == 9 diff --git a/py-polars/tests/unit/io/cloud/test_aws.py b/py-polars/tests/unit/io/cloud/test_aws.py index 03fb72b7eff44..652be4257658c 100644 --- a/py-polars/tests/unit/io/cloud/test_aws.py +++ b/py-polars/tests/unit/io/cloud/test_aws.py @@ -1,5 +1,6 @@ from __future__ import annotations +import multiprocessing from typing import TYPE_CHECKING, Any, Callable, Iterator import boto3 @@ -7,6 +8,7 @@ from moto.server import ThreadedMotoServer import polars as pl +from polars.testing import assert_frame_equal if TYPE_CHECKING: from pathlib import Path @@ -33,12 +35,14 @@ def s3_base(monkeypatch_module: Any) -> Iterator[str]: host = "127.0.0.1" port = 5000 moto_server = ThreadedMotoServer(host, port) - - moto_server.start() + # Start in a separate process to avoid deadlocks + mp = multiprocessing.get_context("spawn") + p = mp.Process(target=moto_server._server_entry, daemon=True) + p.start() print("server up") yield f"http://{host}:{port}" print("moto done") - moto_server.stop() + p.kill() @pytest.fixture() @@ -47,7 +51,7 @@ def s3(s3_base: str, io_files_path: Path) -> str: client = boto3.client("s3", region_name=region, endpoint_url=s3_base) client.create_bucket(Bucket="bucket") - files = ["foods1.csv", "foods1.ipc", "foods1.parquet"] + files = ["foods1.csv", "foods1.ipc", "foods1.parquet", "foods2.parquet"] for file in files: client.upload_file(io_files_path / file, Bucket="bucket", Key=file) return s3_base @@ -55,10 +59,7 @@ def s3(s3_base: str, io_files_path: Path) -> str: @pytest.mark.parametrize( ("function", "extension"), - [ - (pl.read_csv, "csv"), - (pl.read_ipc, "ipc"), - ], + [(pl.read_csv, "csv"), (pl.read_ipc, "ipc")], ) def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None: df = function( @@ -71,9 +72,7 @@ def test_read_s3(s3: str, function: Callable[..., Any], extension: str) -> None: @pytest.mark.parametrize( ("function", "extension"), - [ - (pl.scan_ipc, "ipc"), - ], + [(pl.scan_ipc, "ipc"), (pl.scan_parquet, "parquet")], ) def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None: df = function( @@ -82,3 +81,13 @@ def test_scan_s3(s3: str, function: Callable[..., Any], extension: str) -> None: ) assert df.columns == ["category", "calories", "fats_g", "sugars_g"] assert df.collect().shape == (27, 4) + + +def test_lazy_count_s3(s3: str) -> None: + lf = pl.scan_parquet( + "s3://bucket/foods*.parquet", storage_options={"endpoint_url": s3} + ).select(pl.len()) + + assert "FAST COUNT(*)" in lf.explain() + expected = pl.DataFrame({"len": [54]}, schema={"len": pl.UInt32}) + assert_frame_equal(lf.collect(), expected) diff --git a/py-polars/tests/unit/io/files/graph-data/follows.csv b/py-polars/tests/unit/io/files/graph-data/follows.csv new file mode 100644 index 0000000000000..5ec090c283cd5 --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/follows.csv @@ -0,0 +1,4 @@ +Adam,Karissa,2020 +Adam,Zhang,2020 +Karissa,Zhang,2021 +Zhang,Noura,2022 diff --git a/py-polars/tests/unit/io/files/graph-data/user.csv b/py-polars/tests/unit/io/files/graph-data/user.csv new file mode 100644 index 0000000000000..0421e38ee559f --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/user.csv @@ -0,0 +1,4 @@ +Adam,30 +Karissa,40 +Zhang,50 +Noura,25 diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 463e3b1364a07..1273844413573 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -14,9 +14,10 @@ import zstandard import polars as pl +from polars._utils.various import normalize_filepath from polars.exceptions import ComputeError, NoDataError +from polars.io.csv import BatchedCsvReader from polars.testing import assert_frame_equal, assert_series_equal -from polars.utils.various import normalize_filepath if TYPE_CHECKING: from pathlib import Path @@ -1414,8 +1415,9 @@ def test_csv_categorical_categorical_merge() -> None: def test_batched_csv_reader(foods_file_path: Path) -> None: reader = pl.read_csv_batched(foods_file_path, batch_size=4) - batches = reader.next_batches(5) + assert isinstance(reader, BatchedCsvReader) + batches = reader.next_batches(5) assert batches is not None assert len(batches) == 5 assert batches[0].to_dict(as_series=False) == { @@ -1431,10 +1433,12 @@ def test_batched_csv_reader(foods_file_path: Path) -> None: "sugars_g": [25, 0, 5, 11], } assert_frame_equal(pl.concat(batches), pl.read_csv(foods_file_path)) + # the final batch of the low-memory variant is different reader = pl.read_csv_batched(foods_file_path, batch_size=4, low_memory=True) batches = reader.next_batches(5) assert len(batches) == 5 # type: ignore[arg-type] + batches += reader.next_batches(5) # type: ignore[operator] assert_frame_equal(pl.concat(batches), pl.read_csv(foods_file_path)) @@ -1476,6 +1480,11 @@ def test_batched_csv_reader_no_batches(foods_file_path: Path) -> None: assert batches is None +def test_read_csv_batched_invalid_source() -> None: + with pytest.raises(TypeError): + pl.read_csv_batched(source=5) # type: ignore[arg-type] + + def test_csv_single_categorical_null() -> None: f = io.BytesIO() pl.DataFrame( @@ -1671,7 +1680,7 @@ def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: ) # pytest hijacks sys.stdout and changes its type, which causes mypy failure - df.write_csv(sys.stdout) # type: ignore[call-overload] + df.write_csv(sys.stdout) captured = capsys.readouterr() assert captured.out == ( "numbers,strings,dates\n" @@ -1680,7 +1689,7 @@ def test_write_csv_stdout_stderr(capsys: pytest.CaptureFixture[str]) -> None: "3,stdout,2023-01-03\n" ) - df.write_csv(sys.stderr) # type: ignore[call-overload] + df.write_csv(sys.stderr) captured = capsys.readouterr() assert captured.err == ( "numbers,strings,dates\n" @@ -1820,7 +1829,7 @@ def test_provide_schema() -> None: } -def test_custom_writeable_object() -> None: +def test_custom_writable_object() -> None: df = pl.DataFrame({"a": [10, 20, 30], "b": ["x", "y", "z"]}) class CustomBuffer: @@ -1955,3 +1964,8 @@ def test_read_csv_single_column(columns: list[str] | str) -> None: df = pl.read_csv(f, columns=columns) expected = pl.DataFrame({"b": [2, 5]}) assert_frame_equal(df, expected) + + +def test_csv_invalid_escape_utf8_14960() -> None: + with pytest.raises(pl.ComputeError, match=r"field is not properly escaped"): + pl.read_csv('col1\n""•'.encode()) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 7b4ad1d8bc328..fde1b294d6f02 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,18 +9,18 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple +import pyarrow as pa import pytest from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast import polars as pl -from polars.exceptions import UnsuitableSQLError +from polars.exceptions import ComputeError, UnsuitableSQLError from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal if TYPE_CHECKING: - import pyarrow as pa - from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -33,24 +33,26 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: with suppress(ModuleNotFoundError): # not available on 3.8/windows from adbc_driver_sqlite.dbapi import connect + args = tuple(str(a) if isinstance(a, Path) else a for a in args) return connect(*args, **kwargs) -def create_temp_sqlite_db(test_db: str) -> None: - Path(test_db).unlink(missing_ok=True) +@pytest.fixture() +def tmp_sqlite_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test.db" + test_db.unlink(missing_ok=True) def convert_date(val: bytes) -> date: """Convert ISO 8601 date to datetime.date object.""" return date.fromisoformat(val.decode()) - sqlite3.register_converter("date", convert_date) - # NOTE: at the time of writing adcb/connectorx have weak SQLite support (poor or # no bool/date/datetime dtypes, for example) and there is a bug in connectorx that # causes float rounding < py 3.11, hence we are only testing/storing simple values # in this test db for now. as support improves, we can add/test additional dtypes). - + sqlite3.register_converter("date", convert_date) conn = sqlite3.connect(test_db) + # ┌─────┬───────┬───────┬────────────┐ # │ id ┆ name ┆ value ┆ date │ # │ --- ┆ --- ┆ --- ┆ --- │ @@ -61,17 +63,34 @@ def convert_date(val: bytes) -> date: # └─────┴───────┴───────┴────────────┘ conn.executescript( """ - CREATE TABLE test_data ( + CREATE TABLE IF NOT EXISTS test_data ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, value FLOAT, date DATE ); - INSERT INTO test_data(name,value,date) - VALUES ('misc',100.0,'2020-01-01'), ('other',-99.5,'2021-12-31'); + REPLACE INTO test_data(name,value,date) + VALUES ('misc',100.0,'2020-01-01'), + ('other',-99.5,'2021-12-31'); """ ) conn.close() + return test_db + + +@pytest.fixture() +def tmp_sqlite_inference_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test_inference.db" + test_db.unlink(missing_ok=True) + conn = sqlite3.connect(test_db) + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS test_data (name TEXT, value FLOAT); + REPLACE INTO test_data(name,value) VALUES (NULL,NULL), ('foo',0); + """ + ) + conn.close() + return test_db class DatabaseReadTestParams(NamedTuple): @@ -115,10 +134,10 @@ def __init__( test_data=test_data, ) - def close(self) -> None: # noqa: D102 + def close(self) -> None: pass - def cursor(self) -> Any: # noqa: D102 + def cursor(self) -> Any: return self._cursor @@ -142,10 +161,10 @@ def __getattr__(self, item: str) -> Any: return self.resultset super().__getattr__(item) # type: ignore[misc] - def close(self) -> Any: # noqa: D102 + def close(self) -> Any: pass - def execute(self, query: str) -> Any: # noqa: D102 + def execute(self, query: str) -> Any: return self @@ -160,7 +179,7 @@ def __init__( self.batched = batched self.n_calls = 1 - def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102 + def __call__(self, *args: Any, **kwargs: Any) -> Any: if self.repeat_batched_calls: res = self.test_data[: None if self.n_calls else 0] self.n_calls -= 1 @@ -312,49 +331,71 @@ def test_read_database( expected_dates: list[date | str], schema_overrides: SchemaDict | None, batch_size: int | None, - tmp_path: Path, + tmp_sqlite_db: Path, ) -> None: - tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / "test.db") - create_temp_sqlite_db(test_db) - if read_method == "read_database_uri": # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( - uri=f"sqlite:///{test_db}", + uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", engine=str(connect_using), # type: ignore[arg-type] schema_overrides=schema_overrides, ) + df_empty = pl.read_database_uri( + uri=f"sqlite:///{tmp_sqlite_db}", + query="SELECT * FROM test_data WHERE name LIKE '%polars%'", + engine=str(connect_using), # type: ignore[arg-type] + schema_overrides=schema_overrides, + ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: # externally instantiated adbc connections - with connect_using(test_db) as conn, conn.cursor(): + with connect_using(tmp_sqlite_db) as conn, conn.cursor(): df = pl.read_database( connection=conn, query="SELECT * FROM test_data", schema_overrides=schema_overrides, batch_size=batch_size, ) + df_empty = pl.read_database( + connection=conn, + query="SELECT * FROM test_data WHERE name LIKE '%polars%'", + schema_overrides=schema_overrides, + batch_size=batch_size, + ) else: # other user-supplied connections df = pl.read_database( - connection=connect_using(test_db), + connection=connect_using(tmp_sqlite_db), query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'", schema_overrides=schema_overrides, batch_size=batch_size, ) + df_empty = pl.read_database( + connection=connect_using(tmp_sqlite_db), + query="SELECT * FROM test_data WHERE name LIKE '%polars%'", + schema_overrides=schema_overrides, + batch_size=batch_size, + ) + # validate the expected query return (data and schema) assert df.schema == expected_dtypes assert df.shape == (2, 4) assert df["date"].to_list() == expected_dates + # note: 'cursor.description' is not reliable when no query + # data is returned, so no point comparing expected dtypes + assert df_empty.columns == ["id", "name", "value", "date"] + assert df_empty.shape == (0, 4) + assert df_empty["date"].to_list() == [] -def test_read_database_alchemy_selectable(tmp_path: Path) -> None: - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - conn = create_engine(f"sqlite:///{test_db}") - t = Table("test_data", MetaData(), autoload_with=conn) + +def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + t = Table("test_data", MetaData(), autoload_with=alchemy_engine) # establish sqlalchemy "selectable" and validate usage selectable_query = select( @@ -363,21 +404,19 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: t.c.value, ).where(t.c.value < 0) - assert_frame_equal( - pl.read_database(selectable_query, connection=conn.connect()), - pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), - ) - + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(selectable_query, connection=conn), + pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), + ) -def test_read_database_parameterised(tmp_path: Path) -> None: - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) +def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs - raw_conn: ConnectionOrCursor = sqlite3.connect(test_db) - alchemy_conn: ConnectionOrCursor = create_engine(f"sqlite:///{test_db}").connect() - test_conns = (alchemy_conn, raw_conn) + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db) # establish parameterised queries and validate usage query = """ @@ -385,22 +424,80 @@ def test_read_database_parameterised(tmp_path: Path) -> None: FROM test_data WHERE value < {n} """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + for param, param_value in ( (":n", {"n": 0}), ("?", (0,)), ("?", [0]), ): - for conn in test_conns: + for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn): + if alchemy_session is conn and param == "?": + continue # alchemy session.execute() doesn't support positional params + assert_frame_equal( + expected_frame, pl.read_database( query.format(n=param), connection=conn, execute_options={"parameters": param_value}, ), - pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), ) +@pytest.mark.parametrize( + ("param", "param_value"), + [ + (":n", {"n": 0}), + ("?", (0,)), + ("?", [0]), + ], +) +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc_driver_sqlite not available on py3.8/windows", +) +def test_read_database_parameterised_uri( + param: str, param_value: Any, tmp_sqlite_db: Path +) -> None: + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + uri = alchemy_engine.url.render_as_string(hide_password=False) + query = """ + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < {n} + """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for param, param_value in ( + (":n", pa.Table.from_pydict({"n": [0]})), + ("?", (0,)), + ("?", [0]), + ): + # test URI read method (adbc only) + assert_frame_equal( + expected_frame, + pl.read_database_uri( + query.format(n=param), + uri=uri, + engine="adbc", + execute_options={"parameters": param_value}, + ), + ) + + # no connectorx support for execute_options + with pytest.raises( + ValueError, + match="connectorx.*does not support.*execute_options", + ): + pl.read_database_uri( + query.format(n=":n"), + uri=uri, + engine="connectorx", + execute_options={"parameters": (":n", {"n": 0})}, + ) + + @pytest.mark.parametrize( ("driver", "batch_size", "iter_batches", "expected_call"), [ @@ -587,7 +684,6 @@ def test_read_database_exceptions( engine: DbReadEngine | None, execute_options: dict[str, Any] | None, kwargs: dict[str, Any] | None, - tmp_path: Path, ) -> None: if read_method == "read_database_uri": conn = f"{protocol}://test" if isinstance(protocol, str) else protocol @@ -621,3 +717,92 @@ def test_read_database_cx_credentials(uri: str) -> None: # can reasonably mitigate the issue. with pytest.raises(BaseException, match=r"fakedb://\*\*\*:\*\*\*@\w+"): pl.read_database_uri("SELECT * FROM data", uri=uri) + + +def test_database_infer_schema_length(tmp_sqlite_inference_db: Path) -> None: + # note: first row of this test database contains only NULL values + conn = sqlite3.connect(tmp_sqlite_inference_db) + for infer_len in (2, 100, None): + df = pl.read_database( + connection=conn, + query="SELECT * FROM test_data", + infer_schema_length=infer_len, + ) + assert df.schema == {"name": pl.String, "value": pl.Float64} + + with pytest.raises( + ComputeError, + match='could not append value: "foo" of type: str.*`infer_schema_length`', + ): + pl.read_database( + connection=conn, + query="SELECT * FROM test_data", + infer_schema_length=1, + ) + + +@pytest.mark.write_disk() +def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: + import kuzu + + tmp_path.mkdir(exist_ok=True) + if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): + kuzu_test_db.unlink() + + test_db = str(kuzu_test_db).replace("\\", "/") + + db = kuzu.Database(test_db) + conn = kuzu.Connection(db) + conn.execute("CREATE NODE TABLE User(name STRING, age UINT64, PRIMARY KEY (name))") + conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") + + users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/") + follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/") + + conn.execute(f'COPY User FROM "{users}"') + conn.execute(f'COPY Follows FROM "{follows}"') + + # basic: single relation + df1 = pl.read_database( + query="MATCH (u:User) RETURN u.name, u.age", + connection=conn, + ) + assert_frame_equal( + df1, + pl.DataFrame( + { + "u.name": ["Adam", "Karissa", "Zhang", "Noura"], + "u.age": [30, 40, 50, 25], + }, + schema={"u.name": pl.Utf8, "u.age": pl.UInt64}, + ), + ) + + # join: connected edges/relations + df2 = pl.read_database( + query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name", + connection=conn, + ) + assert_frame_equal( + df2, + pl.DataFrame( + { + "a.name": ["Adam", "Adam", "Karissa", "Zhang"], + "f.since": [2020, 2020, 2021, 2022], + "b.name": ["Karissa", "Zhang", "Zhang", "Noura"], + }, + schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8}, + ), + ) + + # empty: no results for the given query + df3 = pl.read_database( + query="MATCH (a:User)-[f:Follows]->(b:User) WHERE a.name = '🔎️' RETURN a.name, f.since, b.name", + connection=conn, + ) + assert_frame_equal( + df3, + pl.DataFrame( + schema={"a.name": pl.Utf8, "f.since": pl.Int64, "b.name": pl.Utf8} + ), + ) diff --git a/py-polars/tests/unit/io/test_delta.py b/py-polars/tests/unit/io/test_delta.py index 46f097863a981..423c3b4eaa727 100644 --- a/py-polars/tests/unit/io/test_delta.py +++ b/py-polars/tests/unit/io/test_delta.py @@ -120,11 +120,15 @@ def test_write_delta(df: pl.DataFrame, tmp_path: Path) -> None: v1.write_delta(tmp_path) # Case: Overwrite with new version (version 1) - v1.write_delta(tmp_path, mode="overwrite", overwrite_schema=True) + v1.write_delta( + tmp_path, mode="overwrite", delta_write_options={"schema_mode": "overwrite"} + ) # Case: Error if schema contains unsupported columns with pytest.raises(TypeError): - df.write_delta(tmp_path, mode="overwrite", overwrite_schema=True) + df.write_delta( + tmp_path, mode="overwrite", delta_write_options={"schema_mode": "overwrite"} + ) partitioned_tbl_uri = (tmp_path / ".." / "partitioned_table").resolve() @@ -184,6 +188,17 @@ def test_write_delta(df: pl.DataFrame, tmp_path: Path) -> None: df_supported.write_delta(partitioned_tbl_uri, mode="overwrite") +@pytest.mark.write_disk() +def test_write_delta_overwrite_schema_deprecated( + df: pl.DataFrame, tmp_path: Path +) -> None: + df = df.select(pl.col(pl.Int64)) + with pytest.deprecated_call(): + df.write_delta(tmp_path, overwrite_schema=True) + result = pl.read_delta(str(tmp_path)) + assert_frame_equal(df, result) + + @pytest.mark.write_disk() @pytest.mark.parametrize( "series", @@ -360,18 +375,19 @@ def test_write_delta_with_schema_10540(tmp_path: Path) -> None: def test_write_delta_with_tz_in_df(expr: pl.Expr, tmp_path: Path) -> None: df = pl.select(expr) - pa_schema = pa.schema([("datetime", pa.timestamp("us"))]) + expected_dtype = pl.Datetime("us", "UTC") + expected = pl.select(expr.cast(expected_dtype)) df.write_delta(tmp_path, mode="append") # write second time because delta-rs also casts timestamp with tz to timestamp no tz df.write_delta(tmp_path, mode="append") + # Check schema of DeltaTable object tbl = DeltaTable(tmp_path) - assert pa_schema == tbl.schema().to_pyarrow() + assert tbl.schema().to_pyarrow() == expected.to_arrow().schema + # Check result result = pl.read_delta(str(tmp_path), version=0) - - expected = df.cast(pl.Datetime) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/io/test_ipc.py b/py-polars/tests/unit/io/test_ipc.py index 41b47c135ec8a..8c03e3ee31363 100644 --- a/py-polars/tests/unit/io/test_ipc.py +++ b/py-polars/tests/unit/io/test_ipc.py @@ -85,15 +85,28 @@ def test_select_columns_from_file( @pytest.mark.parametrize("stream", [True, False]) def test_select_columns_from_buffer(stream: bool) -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [True, False, True], "c": ["a", "b", "c"]}) - expected = pl.DataFrame({"b": [True, False, True], "c": ["a", "b", "c"]}) + df = pl.DataFrame( + { + "a": [1], + "b": [2], + "c": [3], + } + ) f = io.BytesIO() write_ipc(df, stream, f) f.seek(0) - read_df = read_ipc(stream, f, columns=["b", "c"], use_pyarrow=False) - assert_frame_equal(expected, read_df) + actual = read_ipc(stream, f, columns=["b", "c", "a"], use_pyarrow=False) + + expected = pl.DataFrame( + { + "b": [2], + "c": [3], + "a": [1], + } + ) + assert_frame_equal(expected, actual) @pytest.mark.parametrize("stream", [True, False]) diff --git a/py-polars/tests/unit/io/test_json.py b/py-polars/tests/unit/io/test_json.py index c937d0fe140e1..9acbb061a63c5 100644 --- a/py-polars/tests/unit/io/test_json.py +++ b/py-polars/tests/unit/io/test_json.py @@ -284,10 +284,10 @@ def test_write_json_duration() -> None: ) } ) - assert ( - df.write_json(row_oriented=True) - == '[{"a":"P1DT5362.939S"},{"a":"P1DT5362.890S"},{"a":"PT6020.836S"}]' - ) + + # we don't guarantee a format, just round-circling + value = str(df.write_json(row_oriented=True)) + assert value == """[{"a":"PT91762.939S"},{"a":"PT91762.89S"},{"a":"PT6020.836S"}]""" @pytest.mark.parametrize( @@ -297,7 +297,7 @@ def test_write_json_duration() -> None: ([["a", "b"], [None, None]], pl.Array(pl.Utf8, width=2)), ([[True, False, None], [None, None, None]], pl.Array(pl.Utf8, width=3)), ( - [[[1, 2, 3], [4, None]], None, [[None, None, 2]]], + [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]], pl.List(pl.Array(pl.Int32(), width=3)), ), ( diff --git a/py-polars/tests/unit/io/test_lazy_count_star.py b/py-polars/tests/unit/io/test_lazy_count_star.py new file mode 100644 index 0000000000000..7ab69ad73aab4 --- /dev/null +++ b/py-polars/tests/unit/io/test_lazy_count_star.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pathlib import Path + +from tempfile import NamedTemporaryFile + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.csv", 27), ("foods*.csv", 27 * 5)] +) +def test_count_csv(io_files_path: Path, path: str, n_rows: int) -> None: + lf = pl.scan_csv(io_files_path / path).select(pl.len()) + + expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) + + # Check if we are using our fast count star + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) + + +@pytest.mark.write_disk() +def test_commented_csv() -> None: + csv_a = NamedTemporaryFile() + csv_a.write( + b""" +A,B +Gr1,A +Gr1,B +# comment line + """.strip() + ) + csv_a.seek(0) + + expected = pl.DataFrame(pl.Series("len", [2], dtype=pl.UInt32)) + lf = pl.scan_csv(csv_a.name, comment_prefix="#").select(pl.len()) + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) + + +@pytest.mark.parametrize( + ("pattern", "n_rows"), [("small.parquet", 4), ("foods*.parquet", 54)] +) +def test_count_parquet(io_files_path: Path, pattern: str, n_rows: int) -> None: + lf = pl.scan_parquet(io_files_path / pattern).select(pl.len()) + + expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) + + # Check if we are using our fast count star + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) + + +@pytest.mark.parametrize( + ("path", "n_rows"), [("foods1.ipc", 27), ("foods*.ipc", 27 * 2)] +) +def test_count_ipc(io_files_path: Path, path: str, n_rows: int) -> None: + lf = pl.scan_ipc(io_files_path / path).select(pl.len()) + + expected = pl.DataFrame(pl.Series("len", [n_rows], dtype=pl.UInt32)) + + # Check if we are using our fast count star + assert "FAST COUNT(*)" in lf.explain() + assert_frame_equal(lf.collect(), expected) diff --git a/py-polars/tests/unit/io/test_lazy_ipc.py b/py-polars/tests/unit/io/test_lazy_ipc.py index 8702e83af538e..ecf2a3e657c2e 100644 --- a/py-polars/tests/unit/io/test_lazy_ipc.py +++ b/py-polars/tests/unit/io/test_lazy_ipc.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import sys +from typing import TYPE_CHECKING, Any import pytest import polars as pl +from polars.testing.asserts.frame import assert_frame_equal if TYPE_CHECKING: from pathlib import Path @@ -85,3 +87,30 @@ def test_ipc_list_arg(io_files_path: Path) -> None: assert df.shape == (54, 4) assert df.row(-1) == ("seafood", 194, 12.0, 1) assert df.row(0) == ("vegetables", 45, 0.5, 2) + + +@pytest.mark.skipif( + sys.platform == "win32", reason="object_store does not handle windows-style paths." +) +def test_scan_ipc_local_with_async( + capfd: Any, + monkeypatch: Any, + io_files_path: Path, +) -> None: + monkeypatch.setenv("POLARS_VERBOSE", "1") + monkeypatch.setenv("POLARS_FORCE_ASYNC", "1") + + assert_frame_equal( + pl.scan_ipc(io_files_path / "foods1.ipc").head(1).collect(), + pl.DataFrame( + { + "category": ["vegetables"], + "calories": [45], + "fats_g": [0.5], + "sugars_g": [2], + } + ), + ) + + captured = capfd.readouterr().err + assert "ASYNC READING FORCED" in captured diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 006c245bda20b..89886a569617d 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -21,6 +21,13 @@ from polars.type_aliases import ParquetCompression +def test_round_trip(df: pl.DataFrame) -> None: + f = io.BytesIO() + df.write_parquet(f) + f.seek(0) + assert_frame_equal(pl.read_parquet(f), df) + + COMPRESSIONS = [ "lz4", "uncompressed", @@ -719,3 +726,61 @@ def test_parquet_rle_14333() -> None: pq.write_table(table, f, data_page_version="2.0") f.seek(0) assert pl.read_parquet(f)["a"].to_list() == vals + + +def test_parquet_rle_null_binary_read_14638() -> None: + df = pl.DataFrame({"x": [None]}, schema={"x": pl.String}) + + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=True) + f.seek(0) + assert "RLE_DICTIONARY" in pq.read_metadata(f).row_group(0).column(0).encodings + f.seek(0) + assert_frame_equal(df, pl.read_parquet(f)) + + +def test_parquet_string_rle_encoding() -> None: + n = 3 + data = { + "id": ["abcdefgh"] * n, + } + + df = pl.DataFrame(data) + f = io.BytesIO() + df.write_parquet(f, use_pyarrow=False) + f.seek(0) + + assert ( + "RLE_DICTIONARY" + in pq.ParquetFile(f).metadata.to_dict()["row_groups"][0]["columns"][0][ + "encodings" + ] + ) + + +def test_sliced_dict_with_nulls_14904() -> None: + df = ( + pl.DataFrame({"x": [None, None]}) + .cast(pl.Categorical) + .with_columns(y=pl.concat_list("x")) + .slice(0, 1) + ) + test_round_trip(df) + + +def test_parquet_array_dtype() -> None: + df = pl.DataFrame({"x": [[1, 2, 3]]}) + df = df.cast({"x": pl.Array(pl.Int64, width=3)}) + test_round_trip(df) + + +@pytest.mark.write_disk() +def test_parquet_array_statistics() -> None: + df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "b": [1, 2, 3]}) + df.with_columns(a=pl.col("a").list.to_array(3)).lazy().filter( + pl.col("a") != [1, 2, 3] + ).collect() + df.with_columns(a=pl.col("a").list.to_array(3)).lazy().sink_parquet("test.parquet") + assert pl.scan_parquet("test.parquet").filter( + pl.col("a") != [1, 2, 3] + ).collect().to_dict(as_series=False) == {"a": [[4, 5, 6], [7, 8, 9]], "b": [2, 3]} diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py new file mode 100644 index 0000000000000..1cccd07f67c1b --- /dev/null +++ b/py-polars/tests/unit/io/test_scan.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + +if TYPE_CHECKING: + from pathlib import Path + + from polars.type_aliases import SchemaDict + + +@dataclass +class _RowIndex: + name: str = "index" + offset: int = 0 + + +def _scan( + file_path: Path, + schema: SchemaDict | None = None, + row_index: _RowIndex | None = None, +) -> pl.LazyFrame: + suffix = file_path.suffix + row_index_name = None if row_index is None else row_index.name + row_index_offset = 0 if row_index is None else row_index.offset + if suffix == ".ipc": + return pl.scan_ipc( + file_path, + row_index_name=row_index_name, + row_index_offset=row_index_offset, + ) + if suffix == ".parquet": + return pl.scan_parquet( + file_path, + row_index_name=row_index_name, + row_index_offset=row_index_offset, + ) + if suffix == ".csv": + return pl.scan_csv( + file_path, + schema=schema, + row_index_name=row_index_name, + row_index_offset=row_index_offset, + ) + msg = f"Unknown suffix {suffix}" + raise NotImplementedError(msg) + + +def _write(df: pl.DataFrame, file_path: Path) -> None: + suffix = file_path.suffix + if suffix == ".ipc": + return df.write_ipc(file_path) + if suffix == ".parquet": + return df.write_parquet(file_path) + if suffix == ".csv": + return df.write_csv(file_path) + msg = f"Unknown suffix {suffix}" + raise NotImplementedError(msg) + + +@pytest.fixture( + scope="session", + params=["csv", "ipc", "parquet"], +) +def data_file_extension(request: pytest.FixtureRequest) -> str: + return f".{request.param}" + + +@pytest.fixture(scope="session") +def session_tmp_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + return tmp_path_factory.mktemp("polars-test") + + +@dataclass +class _DataFile: + path: Path + df: pl.DataFrame + + +@pytest.fixture(scope="session") +def data_file_single(session_tmp_dir: Path, data_file_extension: str) -> _DataFile: + file_path = (session_tmp_dir / "data").with_suffix(data_file_extension) + df = pl.DataFrame( + { + "seq_int": range(10000), + "seq_str": [f"{x}" for x in range(10000)], + } + ) + _write(df, file_path) + return _DataFile(path=file_path, df=df) + + +@pytest.fixture(scope="session") +def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile: + row_counts = [ + 100, 186, 95, 185, 90, 84, 115, 81, 87, 217, 126, 85, 98, 122, 129, 122, 1089, 82, + 234, 86, 93, 90, 91, 263, 87, 126, 86, 161, 191, 1368, 403, 192, 102, 98, 115, 81, + 111, 305, 92, 534, 431, 150, 90, 128, 152, 118, 127, 124, 229, 368, 81, + ] # fmt: skip + assert sum(row_counts) == 10000 + assert ( + len(row_counts) < 100 + ) # need to make sure we pad file names with enough zeros, otherwise the lexographical ordering of the file names is not what we want. + + df = pl.DataFrame( + { + "seq_int": range(10000), + "seq_str": [str(x) for x in range(10000)], + } + ) + + row_offset = 0 + for index, row_count in enumerate(row_counts): + file_path = (session_tmp_dir / f"data_{index:02}").with_suffix( + data_file_extension + ) + _write(df.slice(row_offset, row_count), file_path) + row_offset += row_count + return _DataFile( + path=(session_tmp_dir / "data_*").with_suffix(data_file_extension), df=df + ) + + +@pytest.fixture(scope="session", params=["single", "glob"]) +def data_file( + request: pytest.FixtureRequest, + data_file_single: _DataFile, + data_file_glob: _DataFile, +) -> _DataFile: + if request.param == "single": + return data_file_single + if request.param == "glob": + return data_file_glob + raise NotImplementedError() + + +@pytest.mark.write_disk() +def test_scan(data_file: _DataFile) -> None: + df = _scan(data_file.path, data_file.df.schema).collect() + assert_frame_equal(df, data_file.df) + + +@pytest.mark.write_disk() +def test_scan_with_limit(data_file: _DataFile) -> None: + df = _scan(data_file.path, data_file.df.schema).limit(100).collect() + assert_frame_equal( + df, + pl.DataFrame( + { + "seq_int": range(100), + "seq_str": [str(x) for x in range(100)], + } + ), + ) + + +@pytest.mark.write_disk() +def test_scan_with_row_index(data_file: _DataFile) -> None: + df = _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()).collect() + assert_frame_equal( + df, + pl.DataFrame( + { + "index": range(10000), + "seq_int": range(10000), + "seq_str": [str(x) for x in range(10000)], + }, + schema_overrides={"index": pl.UInt32}, + ), + ) + + +@pytest.mark.write_disk() +def test_scan_with_row_index_and_predicate(data_file: _DataFile) -> None: + df = ( + _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()) + .filter(pl.col("seq_int") % 2 == 0) + .collect() + ) + assert_frame_equal( + df, + pl.DataFrame( + { + "index": [2 * x for x in range(5000)], + "seq_int": [2 * x for x in range(5000)], + "seq_str": [str(2 * x) for x in range(5000)], + }, + schema_overrides={"index": pl.UInt32}, + ), + ) diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py index 7bda53f482f14..085b50cb6a01a 100644 --- a/py-polars/tests/unit/io/test_spreadsheet.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -861,7 +861,7 @@ def test_identify_workbook( if file_type == "xlsb": file_type = "xlsx" - # identify from BinaryIO + # identify from IO[bytes] with Path.open(spreadsheet_path, "rb") as f: assert _identify_workbook(f) == file_type diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index 5f2483d690b42..b1211b443b81e 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from datetime import date, datetime, time, timedelta from typing import TYPE_CHECKING @@ -8,19 +7,15 @@ import polars as pl from polars.datatypes import DTYPE_TEMPORAL_UNITS -from polars.dependencies import _ZONEINFO_AVAILABLE from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal -if sys.version_info >= (3, 9): +if TYPE_CHECKING: from zoneinfo import ZoneInfo -elif _ZONEINFO_AVAILABLE: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo -if TYPE_CHECKING: from polars.type_aliases import TemporalLiteral, TimeUnit +else: + from polars._utils.convert import string_to_zoneinfo as ZoneInfo @pytest.fixture() @@ -888,9 +883,9 @@ def test_offset_by_broadcasting() -> None: None, ], "d3": [ - datetime(2020, 10, 26, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 11, 4, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2020, 10, 28, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 26, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 11, 4, tzinfo=ZoneInfo("Europe/London")), + datetime(2020, 10, 28, tzinfo=ZoneInfo("Europe/London")), None, ], "d4": [ @@ -917,8 +912,8 @@ def test_offset_by_broadcasting() -> None: "d1": [datetime(2020, 11, 28), datetime(2021, 2, 5), None], "d2": [datetime(2021, 11, 25), datetime(2022, 2, 2), None], "d3": [ - datetime(2020, 10, 28, tzinfo=ZoneInfo(key="Europe/London")), - datetime(2021, 1, 5, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 28, tzinfo=ZoneInfo("Europe/London")), + datetime(2021, 1, 5, tzinfo=ZoneInfo("Europe/London")), None, ], "d4": [datetime(2021, 11, 26).date(), datetime(2022, 2, 3).date(), None], diff --git a/py-polars/tests/unit/namespaces/test_strptime.py b/py-polars/tests/unit/namespaces/test_strptime.py index cba398d2d7d78..396c14daf1083 100644 --- a/py-polars/tests/unit/namespaces/test_strptime.py +++ b/py-polars/tests/unit/namespaces/test_strptime.py @@ -3,6 +3,7 @@ This method gets its own module due to its complexity. """ + from __future__ import annotations import contextlib @@ -20,7 +21,7 @@ from polars.type_aliases import PolarsTemporalType, TimeUnit else: - from polars.utils.convert import get_zoneinfo as ZoneInfo + from polars._utils.convert import string_to_zoneinfo as ZoneInfo def test_str_strptime() -> None: @@ -472,6 +473,24 @@ def test_to_datetime_ambiguous_or_non_existent() -> None: pl.Series(["2021-03-28 02:30"]).str.to_datetime( time_unit="us", time_zone="Europe/Warsaw" ) + with pytest.raises( + pl.ComputeError, + match="datetime '2021-03-28 02:30:00' is non-existent in time zone 'Europe/Warsaw'", + ): + pl.Series(["2021-03-28 02:30"]).str.to_datetime( + time_unit="us", + time_zone="Europe/Warsaw", + ambiguous="null", + ) + with pytest.raises( + pl.ComputeError, + match="datetime '2021-03-28 02:30:00' is non-existent in time zone 'Europe/Warsaw'", + ): + pl.Series(["2021-03-28 02:30"] * 2).str.to_datetime( + time_unit="us", + time_zone="Europe/Warsaw", + ambiguous=pl.Series(["null", "null"]), + ) @pytest.mark.parametrize( @@ -505,8 +524,8 @@ def test_to_datetime_tz_aware_strptime(ts: str, fmt: str, expected: datetime) -> def test_crossing_dst(format: str) -> None: ts = ["2021-03-27T23:59:59+01:00", "2021-03-28T23:59:59+02:00"] result = pl.Series(ts).str.to_datetime(format) - assert result[0] == datetime(2021, 3, 27, 22, 59, 59, tzinfo=ZoneInfo(key="UTC")) - assert result[1] == datetime(2021, 3, 28, 21, 59, 59, tzinfo=ZoneInfo(key="UTC")) + assert result[0] == datetime(2021, 3, 27, 22, 59, 59, tzinfo=ZoneInfo("UTC")) + assert result[1] == datetime(2021, 3, 28, 21, 59, 59, tzinfo=ZoneInfo("UTC")) @pytest.mark.parametrize("format", ["%+", "%Y-%m-%dT%H:%M:%S%z"]) diff --git a/py-polars/tests/unit/namespaces/test_struct.py b/py-polars/tests/unit/namespaces/test_struct.py index 01ce6e28b78b7..ee4806c00188b 100644 --- a/py-polars/tests/unit/namespaces/test_struct.py +++ b/py-polars/tests/unit/namespaces/test_struct.py @@ -60,7 +60,7 @@ def test_struct_json_encode_logical_type() -> None: } ).select(pl.col("a").struct.json_encode().alias("encoded")) assert df.to_dict(as_series=False) == { - "encoded": ['{"a":["1997-01-01"],"b":["2000-01-29 10:30:00"],"c":["P1DT25S"]}'] + "encoded": ['{"a":["1997-01-01"],"b":["2000-01-29 10:30:00"],"c":["PT86425S"]}'] } diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 670299f889bf1..8f9740a42032a 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -12,10 +12,10 @@ import pytest import polars as pl +from polars._utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser +from polars._utils.various import in_terminal_that_supports_colour from polars.exceptions import PolarsInefficientMapWarning from polars.testing import assert_frame_equal, assert_series_equal -from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser -from polars.utils.various import in_terminal_that_supports_colour MY_CONSTANT = 3 MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} @@ -133,7 +133,7 @@ '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', ), # --------------------------------------------- - # string expr: case/cast ops + # string exprs # --------------------------------------------- ("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.String).str.to_titlecase()'), ( @@ -141,6 +141,21 @@ 'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()', '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', ), + ( + "b", + "lambda x: x.strip().startswith('#')", + """pl.col("b").str.strip_chars().str.starts_with('#')""", + ), + ( + "b", + """lambda x: x.rstrip().endswith(('!','#','?','"'))""", + """pl.col("b").str.strip_chars_end().str.contains(r'(!|\\#|\\?|")$')""", + ), + ( + "b", + """lambda x: x.lstrip().startswith(('!','#','?',"'"))""", + """pl.col("b").str.strip_chars_start().str.contains(r"^(!|\\#|\\?|')")""", + ), # --------------------------------------------- # json expr: load/extract # --------------------------------------------- @@ -168,17 +183,30 @@ 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', ), # --------------------------------------------- + # temporal attributes/methods + # --------------------------------------------- + ( + "f", + "lambda x: x.isoweekday()", + 'pl.col("f").dt.weekday()', + ), + ( + "f", + "lambda x: x.hour + x.minute + x.second", + '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()', + ), + # --------------------------------------------- # Bitwise shifts # --------------------------------------------- ( "a", "lambda x: (3 << (32-x)) & 3", - '(3*2**(32 - pl.col("a"))).cast(pl.Int64) & 3', + '(3 * 2**(32 - pl.col("a"))).cast(pl.Int64) & 3', ), ( "a", "lambda x: (x << 32) & 3", - '(pl.col("a")*2**32).cast(pl.Int64) & 3', + '(pl.col("a") * 2**32).cast(pl.Int64) & 3', ), ( "a", @@ -244,6 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'], "d": ["2020-01-01", "2020-01-02", "2020-01-03"], "e": [1.5, 2.4, 3.1], + "f": [ + datetime(1999, 12, 31), + datetime(2024, 5, 6), + datetime(2077, 10, 20), + ], } ) result_frame = df.select( @@ -254,7 +287,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: x=pl.col(col), y=pl.col(col).map_elements(eval(func)), ) - assert_frame_equal(result_frame, expected_frame) + assert_frame_equal( + result_frame, + expected_frame, + check_dtype=(".dt." not in suggested_expression), + ) @pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") @@ -411,6 +448,15 @@ def test_expr_exact_warning_message() -> None: assert len(warnings) == 1 +def test_omit_implicit_bool() -> None: + parser = BytecodeParser( + function=lambda x: x and x and x.date(), + map_target="expr", + ) + suggested_expression = parser.to_expression("d") + assert suggested_expression == 'pl.col("d").dt.date()' + + def test_partial_functions_13523() -> None: def plus(value, amount: int): # type: ignore[no-untyped-def] return value + amount diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index e30cc160f505a..6ac4a5937a129 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -12,7 +12,7 @@ from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: - from polars.type_aliases import ClosedInterval, TimeUnit + from polars.type_aliases import ClosedInterval, PolarsDataType, TimeUnit @pytest.fixture() @@ -188,18 +188,21 @@ def test_rolling_skew() -> None: @pytest.mark.parametrize("time_zone", [None, "US/Central"]) @pytest.mark.parametrize( - ("rolling_fn", "expected_values"), + ("rolling_fn", "expected_values", "expected_dtype"), [ - ("rolling_mean", [None, 1.0, 2.0, 3.0, 4.0, 5.0]), - ("rolling_sum", [None, 1, 2, 3, 4, 5]), - ("rolling_min", [None, 1, 2, 3, 4, 5]), - ("rolling_max", [None, 1, 2, 3, 4, 5]), - ("rolling_std", [None, 0.0, 0.0, 0.0, 0.0, 0.0]), - ("rolling_var", [None, 0.0, 0.0, 0.0, 0.0, 0.0]), + ("rolling_mean", [None, 1.0, 2.0, 3.0, 4.0, 5.0], pl.Float64), + ("rolling_sum", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_min", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_max", [None, 1, 2, 3, 4, 5], pl.Int64), + ("rolling_std", [None, None, None, None, None, None], pl.Float64), + ("rolling_var", [None, None, None, None, None, None], pl.Float64), ], ) def test_rolling_crossing_dst( - time_zone: str | None, rolling_fn: str, expected_values: list[int | None | float] + time_zone: str | None, + rolling_fn: str, + expected_values: list[int | None | float], + expected_dtype: PolarsDataType, ) -> None: ts = pl.datetime_range( datetime(2021, 11, 5), datetime(2021, 11, 10), "1d", time_zone="UTC", eager=True @@ -208,7 +211,9 @@ def test_rolling_crossing_dst( result = df.with_columns( getattr(pl.col("value"), rolling_fn)("1d", by="ts", closed="left") ) - expected = pl.DataFrame({"ts": ts, "value": expected_values}) + expected = pl.DataFrame( + {"ts": ts, "value": expected_values}, schema_overrides={"value": expected_dtype} + ) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index ad588032d8d00..3c7c40e984b51 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -398,3 +398,44 @@ def test_agg_filter_over_empty_df_13610() -> None: out = df.group_by("a").agg(pl.col("b").filter(pl.col("b").shift())) expected = pl.DataFrame(schema={"a": pl.Int64, "b": pl.List(pl.Boolean)}) assert_frame_equal(out, expected) + + +@pytest.mark.slow() +def test_agg_empty_sum_after_filter_14734() -> None: + f = ( + pl.DataFrame({"a": [1, 2], "b": [1, 2]}) + .lazy() + .group_by("a") + .agg(pl.col("b").filter(pl.lit(False)).sum()) + .collect + ) + + last = f() + + # We need both possible output orders, which should happen within + # 1000 iterations (during testing it usually happens within 10). + limit = 1000 + i = 0 + while (curr := f()).equals(last): + i += 1 + assert i != limit + + expect = pl.Series("b", [0, 0]).to_frame() + assert_frame_equal(expect, last.select("b")) + assert_frame_equal(expect, curr.select("b")) + + +@pytest.mark.slow() +def test_grouping_hash_14749() -> None: + n_groups = 251 + rows_per_group = 4 + assert ( + pl.DataFrame( + { + "grp": np.repeat(np.arange(n_groups), rows_per_group), + "x": np.tile(np.arange(rows_per_group), n_groups), + } + ) + .select(pl.col("x").max().over("grp"))["x"] + .value_counts() + ).to_dict(as_series=False) == {"x": [3], "count": [1004]} diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 4040c112cb246..73722c9660f5a 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -1,18 +1,19 @@ from __future__ import annotations from datetime import date, datetime, time, timedelta +from decimal import Decimal from typing import TYPE_CHECKING, Any import pytest import polars as pl -from polars.testing import assert_frame_equal -from polars.testing.asserts.series import assert_series_equal -from polars.utils.convert import ( +from polars._utils.convert import ( MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND, ) +from polars.testing import assert_frame_equal +from polars.testing.asserts.series import assert_series_equal if TYPE_CHECKING: from polars import PolarsDataType @@ -620,3 +621,22 @@ def test_cast_time_to_date() -> None: msg = "cannot cast `Time` to `Date`" with pytest.raises(pl.ComputeError, match=msg): s.cast(pl.Date) + + +def test_cast_decimal() -> None: + s = pl.Series("s", [Decimal(0), Decimal(1.5), Decimal(-1.5)]) + assert_series_equal(s.cast(pl.Boolean), pl.Series("s", [False, True, True])) + + df = s.to_frame() + assert_frame_equal( + df.select(pl.col("s").cast(pl.Boolean)), + pl.DataFrame({"s": [False, True, True]}), + ) + + +def test_cast_array_to_different_width() -> None: + s = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(pl.Int8, 2)) + with pytest.raises( + pl.InvalidOperationError, match="cannot cast Array to a different width" + ): + s.cast(pl.Array(pl.Int16, 3)) diff --git a/py-polars/tests/unit/operations/test_ewm.py b/py-polars/tests/unit/operations/test_ewm.py new file mode 100644 index 0000000000000..faf0750c689bd --- /dev/null +++ b/py-polars/tests/unit/operations/test_ewm.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest +from hypothesis import given +from hypothesis.strategies import booleans, floats + +import polars as pl +from polars.expr.expr import _prepare_alpha +from polars.testing import assert_series_equal +from polars.testing.parametric import series + + +def test_ewm_mean() -> None: + s = pl.Series([2, 5, 3]) + + expected = pl.Series([2.0, 4.0, 3.4285714285714284]) + assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected + ) + + expected = pl.Series([2.0, 3.8, 3.421053]) + assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=True), expected) + assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=False), expected) + + expected = pl.Series([2.0, 3.5, 3.25]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected + ) + + s = pl.Series([2, 3, 5, 7, 4]) + + expected = pl.Series([None, 2.666667, 4.0, 5.6, 4.774194]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=False), expected + ) + + expected = pl.Series([None, None, 4.0, 5.6, 4.774194]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=True), expected + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=False), expected + ) + + s = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4]) + + expected = pl.Series( + [ + None, + 1.0, + 3.6666666666666665, + 5.571428571428571, + 5.571428571428571, + 3.6666666666666665, + 4.354838709677419, + 4.174603174603175, + ], + ) + assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) + expected = pl.Series( + [ + None, + 1.0, + 3.666666666666667, + 5.571428571428571, + 5.571428571428571, + 3.08695652173913, + 4.2, + 4.092436974789916, + ] + ) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected + ) + + expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.5, 4.25, 4.125]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected + ) + + expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.0, 4.0, 4.0]) + assert_series_equal( + s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected + ) + + +def test_ewm_mean_leading_nulls() -> None: + for min_periods in [1, 2, 3]: + assert ( + pl.Series([1, 2, 3, 4]) + .ewm_mean(com=3, min_periods=min_periods, ignore_nulls=False) + .null_count() + == min_periods - 1 + ) + assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( + alpha=0.5, min_periods=1, ignore_nulls=True + ).to_list() == [None, 1.0, 1.0, 1.0] + assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( + alpha=0.5, min_periods=2, ignore_nulls=True + ).to_list() == [None, None, 1.0, 1.0] + + +def test_ewm_mean_min_periods() -> None: + series = pl.Series([1.0, None, None, None]) + + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1, ignore_nulls=True) + assert ewm_mean.to_list() == [1.0, 1.0, 1.0, 1.0] + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2, ignore_nulls=True) + assert ewm_mean.to_list() == [None, None, None, None] + + series = pl.Series([1.0, None, 2.0, None, 3.0]) + + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1, ignore_nulls=True) + assert_series_equal( + ewm_mean, + pl.Series( + [ + 1.0, + 1.0, + 1.6666666666666665, + 1.6666666666666665, + 2.4285714285714284, + ] + ), + ) + ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2, ignore_nulls=True) + assert_series_equal( + ewm_mean, + pl.Series( + [ + None, + None, + 1.6666666666666665, + 1.6666666666666665, + 2.4285714285714284, + ] + ), + ) + + +def test_ewm_std_var() -> None: + series = pl.Series("a", [2, 5, 3]) + + var = series.ewm_var(alpha=0.5, ignore_nulls=False) + std = series.ewm_std(alpha=0.5, ignore_nulls=False) + + assert np.allclose(var, std**2, rtol=1e-16) + + +def test_ewm_param_validation() -> None: + s = pl.Series("values", range(10)) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_std(com=0.5, alpha=0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False) + + with pytest.raises(ValueError, match="mutually exclusive"): + s.ewm_var(alpha=0.5, span=1.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `com` >= 0"): + s.ewm_std(com=-0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `span` >= 1"): + s.ewm_mean(span=0.5, ignore_nulls=False) + + with pytest.raises(ValueError, match="require `half_life` > 0"): + s.ewm_var(half_life=0, ignore_nulls=False) + + for alpha in (-0.5, -0.0000001, 0.0, 1.0000001, 1.5): + with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"): + s.ewm_std(alpha=alpha, ignore_nulls=False) + + +# https://github.com/pola-rs/polars/issues/4951 +def test_ewm_with_multiple_chunks() -> None: + df0 = pl.DataFrame( + data=[ + ("w", 6.0, 1.0), + ("x", 5.0, 2.0), + ("y", 4.0, 3.0), + ("z", 3.0, 4.0), + ], + schema=["a", "b", "c"], + ).with_columns( + [ + pl.col(pl.Float64).log().diff().name.prefix("ld_"), + ] + ) + assert df0.n_chunks() == 1 + + # NOTE: We aren't testing whether `select` creates two chunks; + # we just need two chunks to properly test `ewm_mean` + df1 = df0.select(["ld_b", "ld_c"]) + assert df1.n_chunks() == 2 + + ewm_std = df1.with_columns( + pl.all().ewm_std(com=20, ignore_nulls=False).name.prefix("ewm_"), + ) + assert ewm_std.null_count().sum_horizontal()[0] == 4 + + +def alpha_guard(**decay_param: float) -> bool: + """Protects against unnecessary noise in small number regime.""" + if not next(iter(decay_param.values())): + return True + alpha = _prepare_alpha(**decay_param) + return ((1 - alpha) if round(alpha) else alpha) > 1e-6 + + +@given( + s=series( + min_size=4, + dtype=pl.Float64, + null_probability=0.05, + strategy=floats(min_value=-1e8, max_value=1e8), + ), + half_life=floats(min_value=0, max_value=4, exclude_min=True).filter( + lambda x: alpha_guard(half_life=x) + ), + com=floats(min_value=0, max_value=99).filter(lambda x: alpha_guard(com=x)), + span=floats(min_value=1, max_value=10).filter(lambda x: alpha_guard(span=x)), + ignore_nulls=booleans(), + adjust=booleans(), + bias=booleans(), +) +def test_ewm_methods( + s: pl.Series, + com: float | None, + span: float | None, + half_life: float | None, + ignore_nulls: bool, + adjust: bool, + bias: bool, +) -> None: + # validate a large set of varied EWM calculations + for decay_param in [{"com": com}, {"span": span}, {"half_life": half_life}]: + alpha = _prepare_alpha(**decay_param) + + # convert parametrically-generated series to pandas, then use that as a + # reference implementation for comparison (after normalising NaN/None) + p = s.to_pandas() + + # note: skip min_periods < 2, due to pandas-side inconsistency: + # https://github.com/pola-rs/polars/issues/5006#issuecomment-1259477178 + for mp in range(2, len(s), len(s) // 3): + # consolidate ewm parameters + pl_params: dict[str, Any] = { + "min_periods": mp, + "adjust": adjust, + "ignore_nulls": ignore_nulls, + } + pl_params.update(decay_param) + pd_params = pl_params.copy() + if "half_life" in pl_params: + pd_params["halflife"] = pd_params.pop("half_life") + if "ignore_nulls" in pl_params: + pd_params["ignore_na"] = pd_params.pop("ignore_nulls") + + # mean: + ewm_mean_pl = s.ewm_mean(**pl_params).fill_nan(None) + ewm_mean_pd = pl.Series(p.ewm(**pd_params).mean()) + if alpha == 1: + # apply fill-forward to nulls to match pandas + # https://github.com/pola-rs/polars/pull/5011#issuecomment-1262318124 + ewm_mean_pl = ewm_mean_pl.fill_null(strategy="forward") + + assert_series_equal(ewm_mean_pl, ewm_mean_pd, atol=1e-07) + + # std: + ewm_std_pl = s.ewm_std(bias=bias, **pl_params).fill_nan(None) + ewm_std_pd = pl.Series(p.ewm(**pd_params).std(bias=bias)) + assert_series_equal(ewm_std_pl, ewm_std_pd, atol=1e-07) + + # var: + ewm_var_pl = s.ewm_var(bias=bias, **pl_params).fill_nan(None) + ewm_var_pd = pl.Series(p.ewm(**pd_params).var(bias=bias)) + assert_series_equal(ewm_var_pl, ewm_var_pd, atol=1e-07) + + +def test_ewm_ignore_nulls_deprecation() -> None: + s = pl.Series([1, None, 3]) + with pytest.deprecated_call(): + s.ewm_mean(com=1.0) + with pytest.deprecated_call(): + s.ewm_std(com=1.0) + with pytest.deprecated_call(): + s.ewm_var(com=1.0) diff --git a/py-polars/tests/unit/operations/test_filter.py b/py-polars/tests/unit/operations/test_filter.py index 533eadd373396..c00dcbdc4786d 100644 --- a/py-polars/tests/unit/operations/test_filter.py +++ b/py-polars/tests/unit/operations/test_filter.py @@ -1,10 +1,11 @@ from datetime import datetime +import numpy as np import pytest import polars as pl from polars import PolarsDataType -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_simplify_expression_lit_true_4376() -> None: @@ -239,3 +240,21 @@ def test_filter_logical_type_13194() -> None: }, ) assert_frame_equal(df, expected_df) + + +@pytest.mark.slow() +@pytest.mark.parametrize( + "dtype", [pl.Boolean, pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.String] +) +@pytest.mark.parametrize("size", list(range(64)) + [100, 1000, 10000]) +@pytest.mark.parametrize("selectivity", [0.0, 0.01, 0.1, 0.5, 0.9, 0.99, 1.0 + 1e-6]) +def test_filter(dtype: PolarsDataType, size: int, selectivity: float) -> None: + rng = np.random.Generator(np.random.PCG64(size * 100 + int(100 * selectivity))) + np_payload = rng.uniform(size=size) * 100.0 + np_mask = rng.uniform(size=size) < selectivity + payload = pl.Series(np_payload).cast(dtype) + mask = pl.Series(np_mask, dtype=pl.Boolean) + + reference = pl.Series(np_payload[np_mask]).cast(dtype) + result = payload.filter(mask) + assert_series_equal(reference, result) diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 97ebba213cd28..308598b612359 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -949,3 +949,21 @@ def test_group_by_with_null() -> None: ) output = df.group_by(["a", "b"], maintain_order=True).agg(pl.col("c")) assert_frame_equal(expected, output) + + +def test_partitioned_group_by_14954(monkeypatch: Any) -> None: + monkeypatch.setenv("POLARS_FORCE_PARTITION", "1") + assert ( + pl.DataFrame({"a": range(20)}) + .select(pl.col("a") % 2) + .group_by("a") + .agg( + (pl.col("a") > 1000).alias("a > 1000"), + ) + ).sort("a").to_dict(as_series=False) == { + "a": [0, 1], + "a > 1000": [ + [False, False, False, False, False, False, False, False, False, False], + [False, False, False, False, False, False, False, False, False, False], + ], + } diff --git a/py-polars/tests/unit/operations/test_group_by_dynamic.py b/py-polars/tests/unit/operations/test_group_by_dynamic.py index 9404b22ea52a9..9fdcebfad510b 100644 --- a/py-polars/tests/unit/operations/test_group_by_dynamic.py +++ b/py-polars/tests/unit/operations/test_group_by_dynamic.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from datetime import date, datetime, timedelta from typing import TYPE_CHECKING, Any @@ -10,15 +9,12 @@ import polars as pl from polars.testing import assert_frame_equal -if sys.version_info >= (3, 9): +if TYPE_CHECKING: from zoneinfo import ZoneInfo -else: - # Import from submodule due to typing issue with backports.zoneinfo package: - # https://github.com/pganssle/zoneinfo/issues/125 - from backports.zoneinfo._zoneinfo import ZoneInfo -if TYPE_CHECKING: from polars.type_aliases import Label, StartBy +else: + from polars._utils.convert import string_to_zoneinfo as ZoneInfo @pytest.mark.parametrize( diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py index 8805e47f71041..e2a0eb5fc4c89 100644 --- a/py-polars/tests/unit/operations/test_is_in.py +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -312,3 +312,85 @@ def test_is_in_with_wildcard_13809() -> None: out = pl.DataFrame({"A": ["B"]}).select(pl.all().is_in(["C"])) expected = pl.DataFrame({"A": [False]}) assert_frame_equal(out, expected) + + +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])]) +def test_cat_is_in_from_str(dtype: pl.DataType) -> None: + s = pl.Series(["c", "c", "b"], dtype=dtype) + + # test local + assert_series_equal( + pl.Series(["a", "d", "e", "b"]).is_in(s), + pl.Series([False, False, False, True]), + ) + + +@pytest.mark.parametrize("dtype", [pl.Categorical, pl.Enum(["a", "b", "c", "d"])]) +def test_cat_list_is_in_from_cat(dtype: pl.DataType) -> None: + df = pl.DataFrame( + [ + (["a", "b"], "c"), + (["a", "b"], "a"), + (["a", None], None), + (["a", "c"], None), + (["a"], "d"), + ], + schema={"li": pl.List(dtype), "x": dtype}, + ) + res = df.select(pl.col("li").list.contains(pl.col("x"))) + expected_df = pl.DataFrame({"li": [False, True, True, False, False]}) + assert_frame_equal(res, expected_df) + + +@pytest.mark.parametrize( + ("val", "expected"), + [ + ("b", [True, False, False, None, True]), + (None, [False, False, True, None, False]), + ("e", [False, False, False, None, False]), + ], +) +def test_cat_list_is_in_from_cat_single(val: str | None, expected: list[bool]) -> None: + df = pl.Series( + "li", + [["a", "b"], ["a", "c"], ["a", None], None, ["b"]], + dtype=pl.List(pl.Categorical), + ).to_frame() + res = df.select(pl.col("li").list.contains(pl.lit(val, dtype=pl.Categorical))) + expected_df = pl.DataFrame({"li": expected}) + assert_frame_equal(res, expected_df) + + +def test_cat_list_is_in_from_str() -> None: + df = pl.DataFrame( + [ + (["a", "b"], "c"), + (["a", "b"], "a"), + (["a", None], None), + (["a", "c"], None), + (["a"], "d"), + ], + schema={"li": pl.List(pl.Categorical), "x": pl.String}, + ) + res = df.select(pl.col("li").list.contains(pl.col("x"))) + expected_df = pl.DataFrame({"li": [False, True, True, False, False]}) + assert_frame_equal(res, expected_df) + + +@pytest.mark.parametrize( + ("val", "expected"), + [ + ("b", [True, False, False, None, True]), + (None, [False, False, True, None, False]), + ("e", [False, False, False, None, False]), + ], +) +def test_cat_list_is_in_from_single_str(val: str | None, expected: list[bool]) -> None: + df = pl.Series( + "li", + [["a", "b"], ["a", "c"], ["a", None], None, ["b"]], + dtype=pl.List(pl.Categorical), + ).to_frame() + res = df.select(pl.col("li").list.contains(pl.lit(val, dtype=pl.String))) + expected_df = pl.DataFrame({"li": expected}) + assert_frame_equal(res, expected_df) diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 97b29dd6aeedb..0656dfd8e9396 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -789,3 +789,17 @@ def test_join_on_nth_error() -> None: pl.ComputeError, match="nth column selection not supported at this point" ): df.join(df2, on=pl.first()) + + +def test_join_results_in_duplicate_names() -> None: + lhs = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [1, 2, 3], + "c_right": [1, 2, 3], + } + ) + rhs = lhs.clone() + with pytest.raises(pl.DuplicateError, match="'c_right' already exists"): + lhs.join(rhs, on=["a", "b"], how="left") diff --git a/py-polars/tests/unit/operations/test_pivot.py b/py-polars/tests/unit/operations/test_pivot.py index 5d0b4a6e69f16..7b63963cc559c 100644 --- a/py-polars/tests/unit/operations/test_pivot.py +++ b/py-polars/tests/unit/operations/test_pivot.py @@ -18,8 +18,8 @@ def test_pivot() -> None: df = pl.DataFrame( { "foo": ["A", "A", "B", "B", "C"], - "N": [1, 2, 2, 4, 2], "bar": ["k", "l", "m", "n", "o"], + "N": [1, 2, 2, 4, 2], } ) result = df.pivot(index="foo", columns="bar", values="N", aggregate_function=None) @@ -35,6 +35,35 @@ def test_pivot() -> None: assert_frame_equal(result, expected) +def test_pivot_no_values() -> None: + df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "bar": ["k", "l", "m", "n", "o"], + "N1": [1, 2, 2, 4, 2], + "N2": [1, 2, 2, 4, 2], + } + ) + result = df.pivot(index="foo", columns="bar", values=None, aggregate_function=None) + expected = pl.DataFrame( + { + "foo": ["A", "B", "C"], + "N1_bar_k": [1, None, None], + "N1_bar_l": [2, None, None], + "N1_bar_m": [None, 2, None], + "N1_bar_n": [None, 4, None], + "N1_bar_o": [None, None, 2], + "N2_bar_k": [1, None, None], + "N2_bar_l": [2, None, None], + "N2_bar_m": [None, 2, None], + "N2_bar_n": [None, 4, None], + "N2_bar_o": [None, None, 2], + } + ) + + assert_frame_equal(result, expected) + + def test_pivot_list() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]}) @@ -77,7 +106,7 @@ def test_pivot_aggregate(agg_fn: PivotAgg, expected_rows: list[tuple[Any]]) -> N } ) result = df.pivot( - values="c", index="b", columns="a", aggregate_function=agg_fn, sort_columns=True + index="b", columns="a", values="c", aggregate_function=agg_fn, sort_columns=True ) assert result.rows() == expected_rows @@ -110,12 +139,12 @@ def test_pivot_categorical_index() -> None: schema=[("A", pl.Categorical), ("B", pl.Categorical)], ) - result = df.pivot(values="B", index=["A"], columns="B", aggregate_function="len") + result = df.pivot(index=["A"], columns="B", values="B", aggregate_function="len") expected = {"A": ["Fire", "Water"], "Car": [1, 2], "Ship": [1, None]} assert result.to_dict(as_series=False) == expected # test expression dispatch - result = df.pivot(values="B", index=["A"], columns="B", aggregate_function=pl.len()) + result = df.pivot(index=["A"], columns="B", values="B", aggregate_function=pl.len()) assert result.to_dict(as_series=False) == expected df = pl.DataFrame( @@ -127,7 +156,7 @@ def test_pivot_categorical_index() -> None: schema=[("A", pl.Categorical), ("B", pl.Categorical), ("C", pl.Categorical)], ) result = df.pivot( - values="B", index=["A", "C"], columns="B", aggregate_function="len" + index=["A", "C"], columns="B", values="B", aggregate_function="len" ) expected = { "A": ["Fire", "Water"], @@ -150,17 +179,17 @@ def test_pivot_multiple_values_column_names_5116() -> None: with pytest.raises(ComputeError, match="found multiple elements in the same group"): df.pivot( - values=["x1", "x2"], index="c1", columns="c2", + values=["x1", "x2"], separator="|", aggregate_function=None, ) result = df.pivot( - values=["x1", "x2"], index="c1", columns="c2", + values=["x1", "x2"], separator="|", aggregate_function="first", ) @@ -185,9 +214,9 @@ def test_pivot_duplicate_names_7731() -> None: } ) result = df.pivot( - values=cs.integer(), index=cs.float(), columns=cs.string(), + values=cs.integer(), aggregate_function="first", ).to_dict(as_series=False) expected = { @@ -202,7 +231,7 @@ def test_pivot_duplicate_names_7731() -> None: def test_pivot_duplicate_names_11663() -> None: df = pl.DataFrame({"a": [1, 2], "b": [1, 2], "c": ["x", "x"], "d": ["x", "y"]}) - result = df.pivot(values="a", index="b", columns=["c", "d"]).to_dict( + result = df.pivot(index="b", columns=["c", "d"], values="a").to_dict( as_series=False ) expected = {"b": [1, 2], '{"x","x"}': [1, None], '{"x","y"}': [None, 2]} @@ -220,7 +249,7 @@ def test_pivot_multiple_columns_12407() -> None: } ) result = df.pivot( - values=["a"], index="b", columns=["c", "e"], aggregate_function="len" + index="b", columns=["c", "e"], values=["a"], aggregate_function="len" ).to_dict(as_series=False) expected = {"b": ["a", "b"], '{"s","x"}': [1, None], '{"f","y"}': [None, 1]} assert result == expected @@ -254,7 +283,7 @@ def test_pivot_index_struct_14101() -> None: "d": [1, 1, 3], } ) - result = df.pivot(index="b", values="a", columns="c") + result = df.pivot(index="b", columns="c", values="a") expected = pl.DataFrame({"b": [{"a": 1}, {"a": 2}], "x": [1, None], "y": [2, 1]}) assert_frame_equal(result, expected) @@ -289,11 +318,11 @@ def test_pivot_floats() -> None: with pytest.raises(ComputeError, match="found multiple elements in the same group"): result = df.pivot( - values="price", index="weight", columns="quantity", aggregate_function=None + index="weight", columns="quantity", values="price", aggregate_function=None ) result = df.pivot( - values="price", index="weight", columns="quantity", aggregate_function="first" + index="weight", columns="quantity", values="price", aggregate_function="first" ) expected = { "weight": [1.0, 4.4, 8.8], @@ -304,9 +333,9 @@ def test_pivot_floats() -> None: assert result.to_dict(as_series=False) == expected result = df.pivot( - values="price", index=["article", "weight"], columns="quantity", + values="price", aggregate_function=None, ) expected = { @@ -329,7 +358,7 @@ def test_pivot_reinterpret_5907() -> None: ) result = df.pivot( - index=["A"], values=["C"], columns=["B"], aggregate_function=pl.element().sum() + index=["A"], columns=["B"], values=["C"], aggregate_function=pl.element().sum() ) expected = {"A": [3, -2], "x": [100, 50], "y": [500, -80]} assert result.to_dict(as_series=False) == expected @@ -389,7 +418,7 @@ def test_aggregate_function_default() -> None: with pytest.raises( pl.ComputeError, match="found multiple elements in the same group" ): - df.pivot(values="a", index="b", columns="c") + df.pivot(index="b", columns="c", values="a") def test_pivot_positional_args_deprecated() -> None: @@ -467,8 +496,29 @@ def test_multi_index_containing_struct() -> None: "d": [1, 1, 3], } ) - result = df.pivot(index=("b", "d"), values="a", columns="c") + result = df.pivot(index=("b", "d"), columns="c", values="a") expected = pl.DataFrame( {"b": [{"a": 1}, {"a": 2}], "d": [1, 3], "x": [1, None], "y": [2, 1]} ) assert_frame_equal(result, expected) + + +def test_list_pivot() -> None: + df = pl.DataFrame( + { + "a": [1, 2, 3, 1], + "b": [[1, 2], [3, 4], [5, 6], [1, 2]], + "c": ["x", "x", "y", "y"], + "d": [1, 2, 3, 4], + } + ) + assert df.pivot( + index=["a", "b"], + columns="c", + values="d", + ).to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [[1, 2], [3, 4], [5, 6]], + "x": [1, 2, None], + "y": [4, None, 3], + } diff --git a/py-polars/tests/unit/operations/test_rolling.py b/py-polars/tests/unit/operations/test_rolling.py index 592bb17673a14..4c7c2b7885603 100644 --- a/py-polars/tests/unit/operations/test_rolling.py +++ b/py-polars/tests/unit/operations/test_rolling.py @@ -9,17 +9,18 @@ from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: - from polars.type_aliases import ClosedInterval + from polars.type_aliases import ClosedInterval, PolarsIntegerType -def test_rolling_group_by_overlapping_groups() -> None: +@pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) +def test_rolling_group_by_overlapping_groups(dtype: PolarsIntegerType) -> None: # this first aggregates overlapping groups so they cannot be naively flattened df = pl.DataFrame({"a": [41, 60, 37, 51, 52, 39, 40]}) assert_series_equal( ( df.with_row_index() - .with_columns(pl.col("index").cast(pl.Int32)) + .with_columns(pl.col("index").cast(dtype)) .rolling(index_column="index", period="5i") .agg( # trigger the apply on the expression engine @@ -31,12 +32,17 @@ def test_rolling_group_by_overlapping_groups() -> None: @pytest.mark.parametrize("input", [[pl.col("b").sum()], pl.col("b").sum()]) -def test_rolling_agg_input_types(input: Any) -> None: - df = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}).set_sorted( - "index_column" - ) +@pytest.mark.parametrize("dtype", [pl.UInt32, pl.UInt64, pl.Int32, pl.Int64]) +def test_rolling_agg_input_types(input: Any, dtype: PolarsIntegerType) -> None: + df = pl.LazyFrame( + {"index_column": [0, 1, 2, 3], "b": [1, 3, 1, 2]}, + schema_overrides={"index_column": dtype}, + ).set_sorted("index_column") result = df.rolling(index_column="index_column", period="2i").agg(input) - expected = pl.LazyFrame({"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}) + expected = pl.LazyFrame( + {"index_column": [0, 1, 2, 3], "b": [1, 4, 4, 3]}, + schema_overrides={"index_column": dtype}, + ) assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_slice.py b/py-polars/tests/unit/operations/test_slice.py index 0aa4fea662094..e163d36806ae3 100644 --- a/py-polars/tests/unit/operations/test_slice.py +++ b/py-polars/tests/unit/operations/test_slice.py @@ -168,3 +168,48 @@ def test_slice_pushdown_set_sorted() -> None: plan = ldf.explain() # check the set sorted is above slice assert plan.index("set_sorted") < plan.index("SLICE") + + +def test_slice_pushdown_literal_projection_14349() -> None: + lf = pl.select(a=pl.int_range(10)).lazy() + expect = pl.DataFrame({"a": [0, 1, 2, 3, 4], "b": [10, 11, 12, 13, 14]}) + + out = lf.with_columns(b=pl.int_range(10, 20, eager=True)).head(5).collect() + assert_frame_equal(expect, out) + + out = lf.select("a", b=pl.int_range(10, 20, eager=True)).head(5).collect() + assert_frame_equal(expect, out) + + assert pl.LazyFrame().select(x=1).head(0).collect().height == 0 + assert pl.LazyFrame().with_columns(x=1).head(0).collect().height == 0 + + q = lf.select(x=1).head(0) + assert q.collect().height == 0 + + # For select, slice pushdown should happen when at least 1 input column is selected + q = lf.select("a", x=1).head(0) + plan = q.explain() + assert plan.index("SELECT") < plan.index("SLICE") + assert q.collect().height == 0 + + # For with_columns, slice pushdown should happen if the input has at least 1 column + q = lf.with_columns(x=1).head(0) + plan = q.explain() + assert plan.index("WITH_COLUMNS") < plan.index("SLICE") + assert q.collect().height == 0 + + q = lf.with_columns(pl.col("a") + 1).head(0) + plan = q.explain() + assert plan.index("WITH_COLUMNS") < plan.index("SLICE") + assert q.collect().height == 0 + + # This does not project any of the original columns + q = lf.with_columns(a=1, b=2).head(0) + plan = q.explain() + assert plan.index("SLICE") < plan.index("WITH_COLUMNS") + assert q.collect().height == 0 + + q = lf.with_columns(b=1, c=2).head(0) + plan = q.explain() + assert plan.index("WITH_COLUMNS") < plan.index("SLICE") + assert q.collect().height == 0 diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 4ea5c3e4059ec..aef3fcd8a5d9a 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -9,6 +9,14 @@ from polars.testing import assert_frame_equal, assert_series_equal +def is_sorted_any(s: pl.Series) -> bool: + return s.flags["SORTED_ASC"] or s.flags["SORTED_DESC"] + + +def is_not_sorted(s: pl.Series) -> bool: + return not is_sorted_any(s) + + def test_sort_dates_multiples() -> None: df = pl.DataFrame( [ @@ -789,3 +797,188 @@ def test_sort_with_null_12272() -> None: ) def test_sort_series_nulls_last(input: list[Any], expected: list[Any]) -> None: assert pl.Series(input).sort(nulls_last=True).to_list() == expected + + +def test_sorted_flag_14552() -> None: + a = pl.DataFrame({"a": [2, 1, 3]}) + + a = pl.concat([a, a], rechunk=False) + assert not a.join(a, on="a", how="left")["a"].flags["SORTED_ASC"] + + +def test_sorted_flag_concat_15072() -> None: + # Both all-null + a = pl.Series("x", [None, None], dtype=pl.Int8) + b = pl.Series("x", [None, None], dtype=pl.Int8) + assert pl.concat((a, b)).flags["SORTED_ASC"] + + # left all-null, right 0 < null_count < len + a = pl.Series("x", [None, None], dtype=pl.Int8) + b = pl.Series("x", [1, 2, 1, None], dtype=pl.Int8) + + out = pl.concat((a, b.sort())) + assert out.to_list() == [None, None, None, 1, 1, 2] + assert out.flags["SORTED_ASC"] + + out = pl.concat((a, b.sort(descending=True))) + assert out.to_list() == [None, None, None, 2, 1, 1] + assert out.flags["SORTED_DESC"] + + out = pl.concat((a, b.sort(nulls_last=True))) + assert out.to_list() == [None, None, 1, 1, 2, None] + assert is_not_sorted(out) + + out = pl.concat((a, b.sort(nulls_last=True, descending=True))) + assert out.to_list() == [None, None, 2, 1, 1, None] + assert is_not_sorted(out) + + # left 0 < null_count < len, right all-null + a = pl.Series("x", [1, 2, 1, None], dtype=pl.Int8) + b = pl.Series("x", [None, None], dtype=pl.Int8) + + out = pl.concat((a.sort(), b)) + assert out.to_list() == [None, 1, 1, 2, None, None] + assert is_not_sorted(out) + + out = pl.concat((a.sort(descending=True), b)) + assert out.to_list() == [None, 2, 1, 1, None, None] + assert is_not_sorted(out) + + out = pl.concat((a.sort(nulls_last=True), b)) + assert out.to_list() == [1, 1, 2, None, None, None] + assert out.flags["SORTED_ASC"] + + out = pl.concat((a.sort(nulls_last=True, descending=True), b)) + assert out.to_list() == [2, 1, 1, None, None, None] + assert out.flags["SORTED_DESC"] + + # both 0 < null_count < len + assert pl.concat( + ( + pl.Series([None, 1]).set_sorted(), + pl.Series([2]).set_sorted(), + ) + ).flags["SORTED_ASC"] + + assert is_not_sorted( + pl.concat( + ( + pl.Series([None, 1]).set_sorted(), + pl.Series([2, None]).set_sorted(), + ) + ) + ) + + assert pl.concat( + ( + pl.Series([None, 2]).set_sorted(descending=True), + pl.Series([1]).set_sorted(descending=True), + ) + ).flags["SORTED_DESC"] + + assert is_not_sorted( + pl.concat( + ( + pl.Series([None, 2]).set_sorted(descending=True), + pl.Series([1, None]).set_sorted(descending=True), + ) + ) + ) + + # Concat with empty series + s = pl.Series([None, 1]).set_sorted() + + out = pl.concat((s.clear(), s)) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + out = pl.concat((s, s.clear())) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + s = pl.Series([1, None]).set_sorted() + + out = pl.concat((s.clear(), s)) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + out = pl.concat((s, s.clear())) + assert_series_equal(out, s) + assert out.flags["SORTED_ASC"] + + +@pytest.mark.parametrize("unit_descending", [True, False]) +def test_sorted_flag_concat_unit(unit_descending: bool) -> None: + unit = pl.Series([1]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [1, 2, 3] + assert out.flags["SORTED_ASC"] + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, 1] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [1, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, 1] + assert out.flags["SORTED_DESC"] + + # unit with nulls first + unit = pl.Series([None, 1]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [None, 1, 2, 3] + assert out.flags["SORTED_ASC"] + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, None, 1] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [None, 1, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, None, 1] + assert is_not_sorted(out) + + # unit with nulls last + unit = pl.Series([1, None]).set_sorted(descending=unit_descending) + + a = unit + b = pl.Series([2, 3]).set_sorted() + + out = pl.concat((a, b)) + assert out.to_list() == [1, None, 2, 3] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [2, 3, 1, None] + assert is_not_sorted(out) + + a = unit + b = pl.Series([3, 2]).set_sorted(descending=True) + + out = pl.concat((a, b)) + assert out.to_list() == [1, None, 3, 2] + assert is_not_sorted(out) + + out = pl.concat((b, a)) + assert out.to_list() == [3, 2, 1, None] + assert out.flags["SORTED_DESC"] diff --git a/py-polars/tests/unit/operations/unique/test_approx_n_unique.py b/py-polars/tests/unit/operations/unique/test_approx_n_unique.py new file mode 100644 index 0000000000000..35b9c15983661 --- /dev/null +++ b/py-polars/tests/unit/operations/unique/test_approx_n_unique.py @@ -0,0 +1,20 @@ +import pytest + +import polars as pl +from polars.testing.asserts.frame import assert_frame_equal + + +def test_df_approx_n_unique_deprecated() -> None: + df = pl.DataFrame({"a": [1, 2, 2], "b": [2, 2, 2]}) + with pytest.deprecated_call(): + result = df.approx_n_unique() + expected = pl.DataFrame({"a": [2], "b": [1]}).cast(pl.UInt32) + assert_frame_equal(result, expected) + + +def test_lf_approx_n_unique_deprecated() -> None: + df = pl.LazyFrame({"a": [1, 2, 2], "b": [2, 2, 2]}) + with pytest.deprecated_call(): + result = df.approx_n_unique() + expected = pl.LazyFrame({"a": [2], "b": [1]}).cast(pl.UInt32) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/unique/test_unique.py b/py-polars/tests/unit/operations/unique/test_unique.py index 51300fefc7109..bb01912d5cc78 100644 --- a/py-polars/tests/unit/operations/unique/test_unique.py +++ b/py-polars/tests/unit/operations/unique/test_unique.py @@ -33,6 +33,20 @@ def test_unique_predicate_pd() -> None: expected = pl.DataFrame({"x": ["abc"], "y": ["xxx"], "z": [True]}) assert_frame_equal(result, expected) + # Issue #14595: filter should not naively be pushed past unique() + for maintain_order in (True, False): + for keep in ("first", "last", "any", "none"): + q = ( + lf.unique("x", maintain_order=maintain_order, keep=keep) # type: ignore[arg-type] + .filter(pl.col("x") == "abc") + .filter(pl.col("z")) + ) + plan = q.explain() + assert r'FILTER col("z")' in plan + # We can push filters if they only depend on the subset columns of unique() + assert r'SELECTION: "[(col(\"x\")) == (String(abc))]"' in plan + assert_frame_equal(q.collect(predicate_pushdown=False), q.collect()) + def test_unique_on_list_df() -> None: assert pl.DataFrame( diff --git a/py-polars/tests/unit/series/test_describe.py b/py-polars/tests/unit/series/test_describe.py index 15ed7bc84c54f..cdf1804232312 100644 --- a/py-polars/tests/unit/series/test_describe.py +++ b/py-polars/tests/unit/series/test_describe.py @@ -117,7 +117,7 @@ def test_series_describe_null() -> None: def test_series_describe_nested_list() -> None: s = pl.Series( values=[[10e10, 10e15], [10e12, 10e13], [10e10, 10e15]], - dtype=pl.List(pl.Int64), + dtype=pl.List(pl.Float64), ) result = s.describe() stats = { diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 39acd2a03ee51..5d9178a582b57 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -11,6 +11,7 @@ import polars import polars as pl +from polars._utils.construction import iterable_to_pyseries from polars.datatypes import ( Date, Datetime, @@ -29,14 +30,13 @@ ) from polars.exceptions import ComputeError, PolarsInefficientMapWarning, ShapeError from polars.testing import assert_frame_equal, assert_series_equal -from polars.utils._construction import iterable_to_pyseries if TYPE_CHECKING: from zoneinfo import ZoneInfo from polars.type_aliases import EpochTimeUnit, PolarsDataType, TimeUnit else: - from polars.utils.convert import get_zoneinfo as ZoneInfo + from polars._utils.convert import string_to_zoneinfo as ZoneInfo def test_cum_agg() -> None: @@ -1044,14 +1044,6 @@ def test_map_elements() -> None: a.map_elements(lambda x: x) -def test_object() -> None: - vals = [[12], "foo", 9] - a = pl.Series("a", vals) - assert a.dtype == pl.Object - assert a.to_list() == vals - assert a[1] == "foo" - - def test_shape() -> None: s = pl.Series([1, 2, 3]) assert s.shape == (3,) @@ -1542,6 +1534,16 @@ def test_to_dummies() -> None: assert_frame_equal(result, expected) +def test_to_dummies_drop_first() -> None: + s = pl.Series("a", [1, 2, 3]) + result = s.to_dummies(drop_first=True) + expected = pl.DataFrame( + {"a_2": [0, 1, 0], "a_3": [0, 0, 1]}, + schema={"a_2": pl.UInt8, "a_3": pl.UInt8}, + ) + assert_frame_equal(result, expected) + + def test_chunk_lengths() -> None: s = pl.Series("a", [1, 2, 2, 3]) # this is a Series with one chunk, of length 4 @@ -1909,176 +1911,6 @@ def test_trigonometric_invalid_input() -> None: s.cosh() -def test_ewm_mean() -> None: - s = pl.Series([2, 5, 3]) - - expected = pl.Series([2.0, 4.0, 3.4285714285714284]) - assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected - ) - - expected = pl.Series([2.0, 3.8, 3.421053]) - assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=True), expected) - assert_series_equal(s.ewm_mean(com=2.0, adjust=True, ignore_nulls=False), expected) - - expected = pl.Series([2.0, 3.5, 3.25]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected - ) - - s = pl.Series([2, 3, 5, 7, 4]) - - expected = pl.Series([None, 2.666667, 4.0, 5.6, 4.774194]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=True), expected - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=2, ignore_nulls=False), expected - ) - - expected = pl.Series([None, None, 4.0, 5.6, 4.774194]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=True), expected - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, min_periods=3, ignore_nulls=False), expected - ) - - s = pl.Series([None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4]) - - expected = pl.Series( - [ - None, - 1.0, - 3.6666666666666665, - 5.571428571428571, - 5.571428571428571, - 3.6666666666666665, - 4.354838709677419, - 4.174603174603175, - ], - ) - assert_series_equal(s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=True), expected) - expected = pl.Series( - [ - None, - 1.0, - 3.666666666666667, - 5.571428571428571, - 5.571428571428571, - 3.08695652173913, - 4.2, - 4.092436974789916, - ] - ) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=True, ignore_nulls=False), expected - ) - - expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.5, 4.25, 4.125]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=True), expected - ) - - expected = pl.Series([None, 1.0, 3.0, 5.0, 5.0, 3.0, 4.0, 4.0]) - assert_series_equal( - s.ewm_mean(alpha=0.5, adjust=False, ignore_nulls=False), expected - ) - - -def test_ewm_mean_leading_nulls() -> None: - for min_periods in [1, 2, 3]: - assert ( - pl.Series([1, 2, 3, 4]) - .ewm_mean(com=3, min_periods=min_periods) - .null_count() - == min_periods - 1 - ) - assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( - alpha=0.5, min_periods=1 - ).to_list() == [None, 1.0, 1.0, 1.0] - assert pl.Series([None, 1.0, 1.0, 1.0]).ewm_mean( - alpha=0.5, min_periods=2 - ).to_list() == [None, None, 1.0, 1.0] - - -def test_ewm_mean_min_periods() -> None: - series = pl.Series([1.0, None, None, None]) - - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1) - assert ewm_mean.to_list() == [1.0, 1.0, 1.0, 1.0] - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2) - assert ewm_mean.to_list() == [None, None, None, None] - - series = pl.Series([1.0, None, 2.0, None, 3.0]) - - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=1) - assert_series_equal( - ewm_mean, - pl.Series( - [ - 1.0, - 1.0, - 1.6666666666666665, - 1.6666666666666665, - 2.4285714285714284, - ] - ), - ) - ewm_mean = series.ewm_mean(alpha=0.5, min_periods=2) - assert_series_equal( - ewm_mean, - pl.Series( - [ - None, - None, - 1.6666666666666665, - 1.6666666666666665, - 2.4285714285714284, - ] - ), - ) - - -def test_ewm_std_var() -> None: - series = pl.Series("a", [2, 5, 3]) - - var = series.ewm_var(alpha=0.5) - std = series.ewm_std(alpha=0.5) - - assert np.allclose(var, std**2, rtol=1e-16) - - -def test_ewm_param_validation() -> None: - s = pl.Series("values", range(10)) - - with pytest.raises(ValueError, match="mutually exclusive"): - s.ewm_std(com=0.5, alpha=0.5) - - with pytest.raises(ValueError, match="mutually exclusive"): - s.ewm_mean(span=1.5, half_life=0.75) - - with pytest.raises(ValueError, match="mutually exclusive"): - s.ewm_var(alpha=0.5, span=1.5) - - with pytest.raises(ValueError, match="require `com` >= 0"): - s.ewm_std(com=-0.5) - - with pytest.raises(ValueError, match="require `span` >= 1"): - s.ewm_mean(span=0.5) - - with pytest.raises(ValueError, match="require `half_life` > 0"): - s.ewm_var(half_life=0) - - for alpha in (-0.5, -0.0000001, 0.0, 1.0000001, 1.5): - with pytest.raises(ValueError, match="require 0 < `alpha` <= 1"): - s.ewm_std(alpha=alpha) - - def test_product() -> None: a = pl.Series("a", [1, 2, 3]) out = a.product() @@ -2279,6 +2111,11 @@ def test_min_max_agg_on_str() -> None: assert (s.min(), s.max()) == ("a", "x") +def test_min_max_full_nan_15058() -> None: + s = pl.Series([float("nan")] * 2) + assert all(x != x for x in [s.min(), s.max()]) + + def test_is_between() -> None: s = pl.Series("num", [1, 2, None, 4, 5]) assert s.is_between(2, 4).to_list() == [False, True, None, True, False] diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py index 77bf04b44fa5a..9659c720ce842 100644 --- a/py-polars/tests/unit/sql/test_temporal.py +++ b/py-polars/tests/unit/sql/test_temporal.py @@ -154,7 +154,6 @@ def test_extract_century_millennium(dt: date, expected: list[int]) -> None: ("ms", [1704589323123, 1609324245987, 1136159999555]), ("us", [1704589323123456, 1609324245987654, 1136159999555555]), ("ns", [1704589323123456000, 1609324245987654000, 1136159999555555000]), - (None, [1704589323123456, 1609324245987654, 1136159999555555]), ], ) def test_timestamp_time_unit(unit: str | None, expected: list[int]) -> None: diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py index 65dd967abb76e..b2eadda91deaa 100644 --- a/py-polars/tests/unit/streaming/test_streaming_categoricals.py +++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py @@ -16,3 +16,16 @@ def test_streaming_nested_categorical() -> None: "numbers": [1, 2], "cat": [["str"], ["bar"]], } + + +def test_streaming_cat_14933() -> None: + df1 = pl.LazyFrame({"a": pl.Series([0], dtype=pl.UInt32)}) + df2 = pl.LazyFrame( + [ + pl.Series("a", [0, 1], dtype=pl.UInt32), + pl.Series("l", [None, None], dtype=pl.Categorical(ordering="physical")), + ] + ) + assert df1.join(df2, on="a", how="left").collect(streaming=True).to_dict( + as_series=False + ) == {"a": [0], "l": [None]} diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index 7c13b7a05804c..6fa9f079f8d30 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -446,3 +446,10 @@ def test_group_by_multiple_keys_one_literal(streaming: bool) -> None: .to_dict(as_series=False) == expected ) + + +def test_streaming_group_null_count() -> None: + df = pl.DataFrame({"g": [1] * 6, "a": ["yes", None] * 3}).lazy() + assert df.group_by("g").agg(pl.col("a").count()).collect(streaming=True).to_dict( + as_series=False + ) == {"g": [1], "a": [3]} diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index b808478037417..bc09443a045c2 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -12,6 +12,34 @@ pytestmark = pytest.mark.xdist_group("streaming") +def test_streaming_outer_joins() -> None: + n = 100 + dfa = pl.DataFrame( + { + "a": np.random.randint(0, 40, n), + "idx": np.arange(0, n), + } + ) + + n = 100 + dfb = pl.DataFrame( + { + "a": np.random.randint(0, 40, n), + "idx": np.arange(0, n), + } + ) + + join_strategies: list[Literal["outer", "outer_coalesce"]] = [ + "outer", + "outer_coalesce", + ] + for how in join_strategies: + q = dfa.lazy().join(dfb.lazy(), on="a", how=how).sort(["idx"]) + a = q.collect(streaming=True) + b = q.collect(streaming=False) + assert_frame_equal(a, b) + + def test_streaming_joins() -> None: n = 100 dfa = pd.DataFrame( @@ -190,3 +218,18 @@ def test_join_null_matches_multiple_keys(streaming: bool) -> None: assert_frame_equal( df_a.join(df_b, on=["a", "idx"], how="outer").sort("a").collect(), expected ) + + +def test_streaming_join_and_union() -> None: + a = pl.LazyFrame({"a": [1, 2]}) + + b = pl.LazyFrame({"a": [1, 2, 4, 8]}) + + c = a.join(b, on="a") + # The join node latest ensures that the dispatcher + # needs to replace placeholders in unions. + q = pl.concat([a, b, c]) + + out = q.collect(streaming=True) + assert_frame_equal(out, q.collect(streaming=False)) + assert out.to_series().to_list() == [1, 2, 1, 2, 4, 8, 1, 2] diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py index 3038d9dcbe14e..e61311459e0c9 100644 --- a/py-polars/tests/unit/streaming/test_streaming_sort.py +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -93,11 +93,16 @@ def test_ooc_sort(tmp_path: Path, monkeypatch: Any) -> None: @pytest.mark.write_disk() -def test_streaming_sort(tmp_path: Path, monkeypatch: Any, capfd: Any) -> None: +@pytest.mark.parametrize("spill_source", [True, False]) +def test_streaming_sort( + tmp_path: Path, monkeypatch: Any, capfd: Any, spill_source: bool +) -> None: tmp_path.mkdir(exist_ok=True) monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") monkeypatch.setenv("POLARS_VERBOSE", "1") + if spill_source: + monkeypatch.setenv("POLARS_SPILL_SORT_PARTITIONS", "1") # this creates a lot of duplicate partitions and triggers: #7568 assert ( pl.Series(np.random.randint(0, 100, 100)) @@ -109,13 +114,20 @@ def test_streaming_sort(tmp_path: Path, monkeypatch: Any, capfd: Any) -> None: ) (_, err) = capfd.readouterr() assert "df -> sort" in err + if spill_source: + assert "PARTITIONED FORCE SPILLED" in err @pytest.mark.write_disk() -def test_out_of_core_sort_9503(tmp_path: Path, monkeypatch: Any) -> None: +@pytest.mark.parametrize("spill_source", [True, False]) +def test_out_of_core_sort_9503( + tmp_path: Path, monkeypatch: Any, spill_source: bool +) -> None: tmp_path.mkdir(exist_ok=True) monkeypatch.setenv("POLARS_TEMP_DIR", str(tmp_path)) monkeypatch.setenv("POLARS_FORCE_OOC", "1") + if spill_source: + monkeypatch.setenv("POLARS_SPILL_SORT_PARTITIONS", "1") np.random.seed(0) num_rows = 100_000 @@ -246,3 +258,9 @@ def test_reverse_variable_sort_13573() -> None: assert df.sort("a", "b", descending=[True, False]).collect(streaming=True).to_dict( as_series=False ) == {"a": ["two", "three", "one"], "b": ["five", "six", "four"]} + + +def test_nulls_last_streaming_sort() -> None: + assert pl.LazyFrame({"x": [1, None]}).sort("x", nulls_last=True).collect( + streaming=True + ).to_dict(as_series=False) == {"x": [1, None]} diff --git a/py-polars/tests/unit/test_config.py b/py-polars/tests/unit/test_config.py index 17b58c7201c9d..f8efda2222b72 100644 --- a/py-polars/tests/unit/test_config.py +++ b/py-polars/tests/unit/test_config.py @@ -8,8 +8,8 @@ import polars as pl import polars.polars as plr +from polars._utils.unstable import issue_unstable_warning from polars.config import _POLARS_CFG_ENV_VARS -from polars.utils.unstable import issue_unstable_warning @pytest.fixture(autouse=True) @@ -87,30 +87,6 @@ def test_hide_header_elements() -> None: ) -def test_html_tables() -> None: - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) - - # default: header contains names/dtypes - header = "abci64i64i64" - assert header in df._repr_html_() - - # validate that relevant config options are respected - with pl.Config(tbl_hide_column_names=True): - header = "i64i64i64" - assert header in df._repr_html_() - - with pl.Config(tbl_hide_column_data_types=True): - header = "abc" - assert header in df._repr_html_() - - with pl.Config( - tbl_hide_column_data_types=True, - tbl_hide_column_names=True, - ): - header = "" - assert header in df._repr_html_() - - def test_set_tbl_cols() -> None: df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) @@ -196,7 +172,7 @@ def test_set_tbl_rows() -> None: "╞═════╪═════╪═════╡\n" "│ 1 ┆ 5 ┆ 9 │\n" "│ 2 ┆ 6 ┆ 10 │\n" - "│ 3 ┆ 7 ┆ 11 │\n" + "│ … ┆ … ┆ … │\n" "│ 4 ┆ 8 ┆ 12 │\n" "└─────┴─────┴─────┘" ) @@ -205,8 +181,8 @@ def test_set_tbl_rows() -> None: "Series: 'ser' [i64]\n" "[\n" "\t1\n" + "\t2\n" "\t…\n" - "\t4\n" "\t5\n" "]" ) @@ -231,7 +207,7 @@ def test_set_tbl_rows() -> None: "[\n" "\t1\n" "\t2\n" - "\t3\n" + "\t…\n" "\t4\n" "\t5\n" "]" @@ -254,8 +230,8 @@ def test_set_tbl_rows() -> None: "│ i64 ┆ i64 ┆ i64 │\n" "╞═════╪═════╪═════╡\n" "│ 1 ┆ 6 ┆ 11 │\n" + "│ 2 ┆ 7 ┆ 12 │\n" "│ … ┆ … ┆ … │\n" - "│ 4 ┆ 9 ┆ 14 │\n" "│ 5 ┆ 10 ┆ 15 │\n" "└─────┴─────┴─────┘" ) diff --git a/py-polars/tests/unit/test_cpu_check.py b/py-polars/tests/unit/test_cpu_check.py new file mode 100644 index 0000000000000..23525f5126ddf --- /dev/null +++ b/py-polars/tests/unit/test_cpu_check.py @@ -0,0 +1,83 @@ +from unittest.mock import Mock + +import pytest + +from polars import _cpu_check +from polars._cpu_check import check_cpu_flags + + +@pytest.fixture() +def _feature_flags(monkeypatch: pytest.MonkeyPatch) -> None: + """Use the default set of feature flags.""" + feature_flags = "+sse3,+ssse3" + monkeypatch.setattr(_cpu_check, "_POLARS_FEATURE_FLAGS", feature_flags) + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags( + monkeypatch: pytest.MonkeyPatch, recwarn: pytest.WarningsRecorder +) -> None: + cpu_flags = {"sse3": True, "ssse3": True} + mock_read_cpu_flags = Mock(return_value=cpu_flags) + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert len(recwarn) == 0 + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags_missing_features(monkeypatch: pytest.MonkeyPatch) -> None: + cpu_flags = {"sse3": True, "ssse3": False} + mock_read_cpu_flags = Mock(return_value=cpu_flags) + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + with pytest.warns(RuntimeWarning, match="Missing required CPU features") as w: + check_cpu_flags() + + assert "ssse3" in str(w[0].message) + + +def test_check_cpu_flags_unknown_flag( + monkeypatch: pytest.MonkeyPatch, +) -> None: + real_cpu_flags = {"sse3": True, "ssse3": False} + mock_read_cpu_flags = Mock(return_value=real_cpu_flags) + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + unknown_feature_flags = "+sse3,+ssse3,+HelloWorld!" + monkeypatch.setattr(_cpu_check, "_POLARS_FEATURE_FLAGS", unknown_feature_flags) + with pytest.raises(RuntimeError, match="unknown feature flag: 'HelloWorld!'"): + check_cpu_flags() + + +def test_check_cpu_flags_skipped_no_flags(monkeypatch: pytest.MonkeyPatch) -> None: + mock_read_cpu_flags = Mock() + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert mock_read_cpu_flags.call_count == 0 + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags_skipped_lts_cpu(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(_cpu_check, "_POLARS_LTS_CPU", True) + + mock_read_cpu_flags = Mock() + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert mock_read_cpu_flags.call_count == 0 + + +@pytest.mark.usefixtures("_feature_flags") +def test_check_cpu_flags_skipped_env_var(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("POLARS_SKIP_CPU_CHECK", "1") + + mock_read_cpu_flags = Mock() + monkeypatch.setattr(_cpu_check, "_read_cpu_flags", mock_read_cpu_flags) + + check_cpu_flags() + + assert mock_read_cpu_flags.call_count == 0 diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index 06e8a906617db..43b2f6e646ea9 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -512,7 +512,8 @@ def test_err_on_invalid_time_zone_cast() -> None: def test_invalid_inner_type_cast_list() -> None: s = pl.Series([[-1, 1]]) with pytest.raises( - pl.ComputeError, match=r"cannot cast List inner type: 'Int64' to Categorical" + pl.InvalidOperationError, + match=r"cannot cast List inner type: 'Int64' to Categorical", ): s.cast(pl.List(pl.Categorical)) @@ -699,3 +700,17 @@ def test_error_lazyframe_not_repeating() -> None: match = "Error originated just after this operation:" assert str(exc_info).count(match) == 1 + + +def test_raise_not_found_in_simplify_14974() -> None: + df = pl.DataFrame() + with pytest.raises(pl.ColumnNotFoundError): + df.select(1 / (1 + pl.col("a"))) + + +def test_invalid_product_type() -> None: + with pytest.raises( + pl.InvalidOperationError, + match="`product` operation not supported for dtype", + ): + pl.Series([[1, 2, 3]]).product() diff --git a/py-polars/tests/unit/test_format.py b/py-polars/tests/unit/test_format.py index f12524e3a27f6..c403e2af7de40 100644 --- a/py-polars/tests/unit/test_format.py +++ b/py-polars/tests/unit/test_format.py @@ -61,22 +61,7 @@ def _environ() -> Iterator[None]: 2 3 4 - 5 - 6 - 7 - 8 - 9 - 10 - 11 … - 87 - 88 - 89 - 90 - 91 - 92 - 93 - 94 95 96 97 diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index dcc387d060fd4..3832942501093 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -738,8 +738,8 @@ def test_rolling(fruits_cars: pl.DataFrame) -> None: ] ).collect() - assert cast(float, out_single_val_variance[0, "std"]) == 0.0 - assert cast(float, out_single_val_variance[0, "var"]) == 0.0 + assert cast(float, out_single_val_variance[0, "std"]) is None + assert cast(float, out_single_val_variance[0, "var"]) is None def test_arr_namespace(fruits_cars: pl.DataFrame) -> None: diff --git a/py-polars/tests/unit/test_plugins.py b/py-polars/tests/unit/test_plugins.py new file mode 100644 index 0000000000000..b983c9a9044f4 --- /dev/null +++ b/py-polars/tests/unit/test_plugins.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +import polars as pl +from polars.plugins import ( + _is_dynamic_lib, + _resolve_plugin_path, + _serialize_kwargs, + register_plugin_function, +) + + +@pytest.mark.write_disk() +def test_register_plugin_function_invalid_plugin_path(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + plugin_path = tmp_path / "lib.so" + plugin_path.touch() + + expr = register_plugin_function( + plugin_path=plugin_path, function_name="hello", args=5 + ) + + with pytest.raises(pl.ComputeError, match="error loading dynamic library"): + pl.select(expr) + + +@pytest.mark.parametrize( + ("input", "expected"), + [ + (None, b""), + ({}, b""), + ( + {"hi": 0}, + b"\x80\x05\x95\x0b\x00\x00\x00\x00\x00\x00\x00}\x94\x8c\x02hi\x94K\x00s.", + ), + ], +) +def test_serialize_kwargs(input: dict[str, Any] | None, expected: bytes) -> None: + assert _serialize_kwargs(input) == expected + + +@pytest.mark.write_disk() +def test_resolve_plugin_path(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + (tmp_path / "lib1.so").touch() + (tmp_path / "__init__.py").touch() + + expected = tmp_path / "lib1.so" + + result = _resolve_plugin_path(tmp_path) + assert result == expected + result = _resolve_plugin_path(tmp_path / "lib1.so") + assert result == expected + result = _resolve_plugin_path(str(tmp_path)) + assert result == expected + result = _resolve_plugin_path(str(tmp_path / "lib1.so")) + assert result == expected + + +@pytest.mark.write_disk() +def test_resolve_plugin_path_raises(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + (tmp_path / "__init__.py").touch() + + with pytest.raises(FileNotFoundError, match="no dynamic library found"): + _resolve_plugin_path(tmp_path) + + +@pytest.mark.write_disk() +@pytest.mark.parametrize( + ("path", "expected"), + [ + (Path("lib.so"), True), + (Path("lib.pyd"), True), + (Path("lib.dll"), True), + (Path("lib.py"), False), + ], +) +def test_is_dynamic_lib(path: Path, expected: bool, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + full_path = tmp_path / path + full_path.touch() + assert _is_dynamic_lib(full_path) is expected + + +@pytest.mark.write_disk() +def test_is_dynamic_lib_dir(tmp_path: Path) -> None: + path = Path("lib.so") + full_path = tmp_path / path + + full_path.mkdir(exist_ok=True) + (full_path / "hello.txt").touch() + + assert _is_dynamic_lib(full_path) is False diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index 0b14f4ff2a190..56db8b65911ef 100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -491,3 +491,16 @@ def test_filter_eq_missing_13861() -> None: pl.col("a").is_null(), ): assert lf.collect().filter(filter_expr).rows() == [(None, "yy")] + + +@pytest.mark.parametrize("how", ["left", "inner"]) +def test_predicate_pushdown_block_join(how: Any) -> None: + q = ( + pl.LazyFrame({"a": [1]}) + .join( + pl.LazyFrame({"a": [2], "b": [1]}), left_on=["a"], right_on=["b"], how=how + ) + .filter(pl.col("a") == 1) + ) + + assert_frame_equal(q.collect(no_optimization=True), q.collect()) diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 6fb79230e89b0..8fe6b1131174e 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -33,13 +33,10 @@ def test_lazyframe_serde() -> None: def test_serde_time_unit() -> None: - assert pickle.loads( - pickle.dumps( - pl.Series( - [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)] - ).cast(pl.Datetime("ns")) - ) - ).dtype == pl.Datetime("ns") + values = [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)] + s = pl.Series(values).cast(pl.Datetime("ns")) + result = pickle.loads(pickle.dumps(s)) + assert result.dtype == pl.Datetime("ns") def test_serde_duration() -> None: @@ -103,14 +100,6 @@ def test_deser_empty_list() -> None: assert s.to_list() == [[[42.0]], []] -def test_expression_json() -> None: - e = pl.col("foo").sum().over("bar") - json = e.meta.write_json() - - round_tripped = pl.Expr.from_json(json) - assert round_tripped.meta == e - - def times2(x: pl.Series) -> pl.Series: return x * 2 @@ -199,15 +188,70 @@ def test_serde_array_dtype() -> None: assert_series_equal(pickle.loads(pickle.dumps(s)), s) nested_s = pl.Series( - [[[1, 2, 3], [4, None]], None, [[None, None, 2]]], + [[[1, 2, 3], [4, None, 5]], None, [[None, None, 2]]], dtype=pl.List(pl.Array(pl.Int32(), width=3)), ) assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s) +def test_expr_serialization_roundtrip() -> None: + expr = pl.col("foo").sum().over("bar") + json = expr.meta.serialize() + round_tripped = pl.Expr.deserialize(io.StringIO(json)) + assert round_tripped.meta == expr + + +def test_expr_deserialize_file_not_found() -> None: + with pytest.raises(FileNotFoundError): + pl.Expr.deserialize("abcdef") + + +def test_expr_deserialize_invalid_json() -> None: + with pytest.raises( + pl.ComputeError, match="could not deserialize input into an expression" + ): + pl.Expr.deserialize(io.StringIO("abcdef")) + + +def test_expr_write_json_from_json_deprecated() -> None: + expr = pl.col("foo").sum().over("bar") + + with pytest.deprecated_call(): + json = expr.meta.write_json() + + with pytest.deprecated_call(): + round_tripped = pl.Expr.from_json(json) + + assert round_tripped.meta == expr + + def test_expression_json_13991() -> None: - e = pl.col("foo").cast(pl.Decimal) - json = e.meta.write_json() + expr = pl.col("foo").cast(pl.Decimal) + json = expr.meta.serialize() + + round_tripped = pl.Expr.deserialize(io.StringIO(json)) + assert round_tripped.meta == expr + + +def test_serde_data_type_class() -> None: + dtype = pl.Datetime + serialized = pickle.dumps(dtype) + deserialized = pickle.loads(serialized) + assert deserialized == dtype + assert isinstance(deserialized, type) + + +def test_serde_data_type_instantiated() -> None: + dtype = pl.Int8() + serialized = pickle.dumps(dtype) + deserialized = pickle.loads(serialized) + assert deserialized == dtype + assert isinstance(deserialized, pl.DataType) + - round_tripped = pl.Expr.from_json(json) - assert round_tripped.meta == e +def test_serde_data_type_instantiated_with_attributes() -> None: + dtype = pl.Enum(["a", "b"]) + serialized = pickle.dumps(dtype) + deserialized = pickle.loads(serialized) + assert deserialized == dtype + assert isinstance(deserialized, pl.DataType) diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py index bf8727c178a1b..5b2e9f92092ad 100644 --- a/py-polars/tests/unit/testing/test_assert_frame_equal.py +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -250,8 +250,8 @@ def test_compare_frame_equal_nested_nans() -> None: { "id": 2, "struct": [ - {"x": "text", "y": [nan, 1], "z": ["!"]}, - {"x": "text", "y": [nan, 1], "z": ["?"]}, + {"x": "text", "y": [nan, 1.0], "z": ["!"]}, + {"x": "text", "y": [nan, 1.0], "z": ["?"]}, ], }, ] @@ -342,14 +342,18 @@ def test_assert_frame_equal_ignore_row_order() -> None: assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False) + class Foo: + def __init__(self) -> None: + pass + # note: not all column types support sorting with pytest.raises( InvalidAssert, match="cannot set `check_row_order=False`.*unsortable columns", ): assert_frame_equal( - left=pl.DataFrame({"a": [[1, 2], [3, 4]], "b": [3, 4]}), - right=pl.DataFrame({"a": [[3, 4], [1, 2]], "b": [4, 3]}), + left=pl.DataFrame({"a": [Foo(), Foo()], "b": [3, 4]}), + right=pl.DataFrame({"a": [Foo(), Foo()], "b": [4, 3]}), check_row_order=False, ) diff --git a/py-polars/tests/unit/utils/test_deprecation.py b/py-polars/tests/unit/utils/test_deprecation.py index fee0a18724b6d..5c067404afd0d 100644 --- a/py-polars/tests/unit/utils/test_deprecation.py +++ b/py-polars/tests/unit/utils/test_deprecation.py @@ -5,7 +5,7 @@ import pytest -from polars.utils.deprecation import ( +from polars._utils.deprecation import ( deprecate_function, deprecate_nonkeyword_arguments, deprecate_renamed_function, @@ -21,8 +21,7 @@ def test_issue_deprecation_warning() -> None: def test_deprecate_function() -> None: @deprecate_function("This is deprecated.", version="1.0.0") - def hello() -> None: - ... + def hello() -> None: ... with pytest.deprecated_call(): hello() @@ -30,8 +29,7 @@ def hello() -> None: def test_deprecate_renamed_function() -> None: @deprecate_renamed_function("new_hello", version="1.0.0") - def hello() -> None: - ... + def hello() -> None: ... with pytest.deprecated_call(match="new_hello"): hello() @@ -40,8 +38,7 @@ def hello() -> None: def test_deprecate_renamed_parameter(recwarn: Any) -> None: @deprecate_renamed_parameter("foo", "oof", version="1.0.0") @deprecate_renamed_parameter("bar", "rab", version="2.0.0") - def hello(oof: str, rab: str, ham: str) -> None: - ... + def hello(oof: str, rab: str, ham: str) -> None: ... hello(foo="x", bar="y", ham="z") # type: ignore[call-arg] @@ -52,10 +49,9 @@ def hello(oof: str, rab: str, ham: str) -> None: class Foo: # noqa: D101 @deprecate_nonkeyword_arguments(allowed_args=["self", "baz"], version="0.1.2") - def bar( # noqa: D102 + def bar( self, baz: str, ham: str | None = None, foobar: str | None = None - ) -> None: - ... + ) -> None: ... def test_deprecate_nonkeyword_arguments_method_signature() -> None: diff --git a/py-polars/tests/unit/utils/test_parse_expr_input.py b/py-polars/tests/unit/utils/test_parse_expr_input.py index 8c58c1307688e..a17debfc94cb8 100644 --- a/py-polars/tests/unit/utils/test_parse_expr_input.py +++ b/py-polars/tests/unit/utils/test_parse_expr_input.py @@ -6,9 +6,9 @@ import pytest import polars as pl +from polars._utils.parse_expr_input import parse_as_expression +from polars._utils.wrap import wrap_expr from polars.testing import assert_frame_equal -from polars.utils._parse_expr_input import parse_as_expression -from polars.utils._wrap import wrap_expr def assert_expr_equal(result: pl.Expr, expected: pl.Expr) -> None: diff --git a/py-polars/tests/unit/utils/test_unstable.py b/py-polars/tests/unit/utils/test_unstable.py index ea9e5d594c9fe..8a0c738abe202 100644 --- a/py-polars/tests/unit/utils/test_unstable.py +++ b/py-polars/tests/unit/utils/test_unstable.py @@ -3,7 +3,7 @@ import pytest import polars as pl -from polars.utils.unstable import issue_unstable_warning, unstable +from polars._utils.unstable import issue_unstable_warning, unstable def test_issue_unstable_warning(monkeypatch: pytest.MonkeyPatch) -> None: @@ -37,8 +37,7 @@ def test_unstable_decorator(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("POLARS_WARN_UNSTABLE", "1") @unstable() - def hello() -> None: - ... + def hello() -> None: ... msg = "`hello` is considered unstable." with pytest.warns(pl.UnstableWarning, match=msg): @@ -47,8 +46,7 @@ def hello() -> None: def test_unstable_decorator_setting_disabled(recwarn: pytest.WarningsRecorder) -> None: @unstable() - def hello() -> None: - ... + def hello() -> None: ... hello() assert len(recwarn) == 0 diff --git a/py-polars/tests/unit/utils/test_utils.py b/py-polars/tests/unit/utils/test_utils.py index fc84cc3d59f8c..d81540efc10d3 100644 --- a/py-polars/tests/unit/utils/test_utils.py +++ b/py-polars/tests/unit/utils/test_utils.py @@ -7,15 +7,14 @@ import pytest import polars as pl -from polars.io._utils import _looks_like_url -from polars.utils.convert import ( - _date_to_pl_date, - _datetime_to_pl_timestamp, - _time_to_pl_time, - _timedelta_to_pl_duration, - _timedelta_to_pl_timedelta, +from polars._utils.convert import ( + date_to_int, + datetime_to_int, + parse_as_duration_string, + time_to_int, + timedelta_to_int, ) -from polars.utils.various import ( +from polars._utils.various import ( _in_notebook, is_bool_sequence, is_int_sequence, @@ -24,74 +23,137 @@ parse_percentiles, parse_version, ) +from polars.io._utils import _looks_like_url if TYPE_CHECKING: + from zoneinfo import ZoneInfo + from polars.type_aliases import TimeUnit +else: + from polars._utils.convert import string_to_zoneinfo as ZoneInfo @pytest.mark.parametrize( - ("dt", "time_unit", "expected"), + ("td", "expected"), [ - (datetime(2121, 1, 1), "ns", 4765132800000000000), - (datetime(2121, 1, 1), "us", 4765132800000000), - (datetime(2121, 1, 1), "ms", 4765132800000), + (timedelta(), ""), + (timedelta(days=1), "1d"), + (timedelta(days=-1), "-1d"), + (timedelta(seconds=1), "1s"), + (timedelta(seconds=-1), "-1s"), + (timedelta(microseconds=1), "1us"), + (timedelta(microseconds=-1), "-1us"), + (timedelta(days=1, seconds=1), "1d1s"), + (timedelta(minutes=-1, seconds=1), "-59s"), + (timedelta(days=-1, seconds=-1), "-1d1s"), + (timedelta(days=1, microseconds=1), "1d1us"), + (timedelta(days=-1, microseconds=-1), "-1d1us"), + (None, None), + ("1d2s", "1d2s"), ], ) -def test_datetime_to_pl_timestamp( - dt: datetime, time_unit: TimeUnit, expected: int +def test_parse_as_duration_string( + td: timedelta | str | None, expected: str | None ) -> None: - out = _datetime_to_pl_timestamp(dt, time_unit) - assert out == expected + assert parse_as_duration_string(td) == expected + + +@pytest.mark.parametrize( + ("d", "expected"), + [ + (date(1999, 9, 9), 10_843), + (date(1969, 12, 31), -1), + (date.min, -719_162), + (date.max, 2_932_896), + ], +) +def test_date_to_int(d: date, expected: int) -> None: + assert date_to_int(d) == expected @pytest.mark.parametrize( ("t", "expected"), [ - (time(0, 0, 0), 0), (time(0, 0, 1), 1_000_000_000), (time(20, 52, 10), 75_130_000_000_000), (time(20, 52, 10, 200), 75_130_000_200_000), + (time.min, 0), + (time.max, 86_399_999_999_000), + (time(12, 0, tzinfo=None), 43_200_000_000_000), + (time(12, 0, tzinfo=ZoneInfo("UTC")), 43_200_000_000_000), + (time(12, 0, tzinfo=ZoneInfo("Asia/Shanghai")), 43_200_000_000_000), + (time(12, 0, tzinfo=ZoneInfo("US/Central")), 43_200_000_000_000), ], ) -def test_time_to_pl_time(t: time, expected: int) -> None: - assert _time_to_pl_time(t) == expected +def test_time_to_int(t: time, expected: int) -> None: + assert time_to_int(t) == expected -def test_date_to_pl_date() -> None: - d = date(1999, 9, 9) - out = _date_to_pl_date(d) - assert out == 10843 +@pytest.mark.parametrize( + "tzinfo", [None, ZoneInfo("UTC"), ZoneInfo("Asia/Shanghai"), ZoneInfo("US/Central")] +) +def test_time_to_int_with_time_zone(tzinfo: Any) -> None: + t = time(12, 0, tzinfo=tzinfo) + assert time_to_int(t) == 43_200_000_000_000 -def test_timedelta_to_pl_timedelta() -> None: - out = _timedelta_to_pl_timedelta(timedelta(days=1), "ns") - assert out == 86_400_000_000_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), "us") - assert out == 86_400_000_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), "ms") - assert out == 86_400_000 - out = _timedelta_to_pl_timedelta(timedelta(days=1), time_unit=None) - assert out == 86_400_000_000 +@pytest.mark.parametrize( + ("dt", "time_unit", "expected"), + [ + (datetime(2121, 1, 1), "ns", 4_765_132_800_000_000_000), + (datetime(2121, 1, 1), "us", 4_765_132_800_000_000), + (datetime(2121, 1, 1), "ms", 4_765_132_800_000), + (datetime(1969, 12, 31, 23, 59, 59, 999999), "us", -1), + (datetime(1969, 12, 30, 23, 59, 59, 999999), "us", -86_400_000_001), + (datetime.min, "ns", -62_135_596_800_000_000_000), + (datetime.max, "ns", 253_402_300_799_999_999_000), + (datetime.min, "ms", -62_135_596_800_000), + (datetime.max, "ms", 253_402_300_799_999), + ], +) +def test_datetime_to_int(dt: datetime, time_unit: TimeUnit, expected: int) -> None: + assert datetime_to_int(dt, time_unit) == expected @pytest.mark.parametrize( - ("td", "expected"), + ("dt", "expected"), [ - (timedelta(days=1), "1d"), - (timedelta(days=-1), "-1d"), - (timedelta(seconds=1), "1s"), - (timedelta(seconds=-1), "-1s"), - (timedelta(microseconds=1), "1us"), - (timedelta(microseconds=-1), "-1us"), - (timedelta(days=1, seconds=1), "1d1s"), - (timedelta(days=-1, seconds=-1), "-1d1s"), - (timedelta(days=1, microseconds=1), "1d1us"), - (timedelta(days=-1, microseconds=-1), "-1d1us"), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=None), + 946_728_000_000_000, + ), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=ZoneInfo("UTC")), + 946_728_000_000_000, + ), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=ZoneInfo("Asia/Shanghai")), + 946_699_200_000_000, + ), + ( + datetime(2000, 1, 1, 12, 0, tzinfo=ZoneInfo("US/Central")), + 946_749_600_000_000, + ), + ], +) +def test_datetime_to_int_with_time_zone(dt: datetime, expected: int) -> None: + assert datetime_to_int(dt, "us") == expected + + +@pytest.mark.parametrize( + ("td", "time_unit", "expected"), + [ + (timedelta(days=1), "ns", 86_400_000_000_000), + (timedelta(days=1), "us", 86_400_000_000), + (timedelta(days=1), "ms", 86_400_000), + (timedelta.min, "ns", -86_399_999_913_600_000_000_000), + (timedelta.max, "ns", 86_399_999_999_999_999_999_000), + (timedelta.min, "ms", -86_399_999_913_600_000), + (timedelta.max, "ms", 86_399_999_999_999_999), ], ) -def test_timedelta_to_pl_duration(td: timedelta, expected: str) -> None: - out = _timedelta_to_pl_duration(td) - assert out == expected +def test_timedelta_to_int(td: timedelta, time_unit: TimeUnit, expected: int) -> None: + assert timedelta_to_int(td, time_unit) == expected def test_estimated_size() -> None: diff --git a/rust-toolchain.toml b/rust-toolchain.toml index f1b98f9ea7120..5f75e2c9af812 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "nightly-2024-01-24" +channel = "nightly-2024-02-23"